mnist数据集手写体识别(mlp实现)(代码片段)

mrzhang3389 mrzhang3389     2023-01-14     723

关键词:

github博客传送门
csdn博客传送门

本章所需知识:

  1. 没有基础的请观看深度学习系列视频
  2. tensorflow
  3. Python基础

    资料下载链接:

  4. 深度学习基础网络模型(mnist手写体识别数据集)

    MNIST数据集手写体识别(MLP实现)

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data  # 导入下载数据集手写体
mnist = input_data.read_data_sets(‘../MNIST_data/‘, one_hot=True)


class MLPNet:  # 创建一个MLPNet类
    def __init__(self):
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, 784], name=‘input_x‘)  # 创建一个tensorflow占位符(稍后传入图片数据),定义数据类型为tf.float32,形状shape为 None为批次 784为数据集撑开的 28*28的手写体图片 name可选参数
        self.y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name=‘input_label‘)  # 创建一个tensorflow占位符(稍后传入图片标签), name可选参数

        self.w1 = tf.Variable(tf.truncated_normal(shape=[784, 100], dtype=tf.float32, stddev=tf.sqrt(1 / 100)))  # 定义全链接 第一层/输入层 变量w/神经元w/特征数w 为截断正态分布 形状shape为[784, 100](100个神经元个数的矩阵), stddev(标准差的值 一般情况为sqrt(1/当前神经元个数))
        self.b1 = tf.Variable(tf.zeros([100], dtype=tf.float32))  # 定义变量偏值b 值为零 形状shape为 [100] 当前层w 的个数, 数据类型 dtype为tf.float32

        self.w2 = tf.Variable(tf.truncated_normal(shape=[100, 10], dtype=tf.float32, stddev=tf.sqrt(1 / 10)))  # 定义全链接 第二层/输出层 变量w/神经元w/特征数w 为截断正态分布 形状为[上一层神经元个数, 当前输出神经元个数] 由于我们是识别十分类问题,手写体 0,1,2,3,4,5,6,7,8,9所以我们选择十个神经元 stddev同上
        self.b2 = tf.Variable(tf.zeros([10], dtype=tf.float32))  # 设置原理同上

    # 前向计算
    def forward(self):
        self.forward_1 = tf.nn.relu(tf.matmul(self.x, self.w1) + self.b1)  # 全链接第一层
        self.forward_2 = tf.nn.relu(tf.matmul(self.forward_1, self.w2) + self.b2)  # 全链接第二层
        self.output = tf.nn.softmax(self.forward_2)  # softmax分类器分类
    
    # 后向计算
    def backward(self):
        self.cost = tf.reduce_mean(tf.square(self.output - self.y))  # 定义均方差损失
        self.opt = tf.train.AdamOptimizer().minimize(self.cost)      # 使用AdamOptimizer优化器 优化 self.cost损失函数

    # 计算识别精度
    def acc(self):
        # 将预测值 output 和 标签值 self.y 进行比较
        self.z = tf.equal(tf.argmax(self.output, 1, name=‘output_max‘), tf.argmax(self.y, 1, name=‘y_max‘))
        # 最后对比较出来的bool值 转换为float32类型后 求均值就可以看到满值为 1的精度显示
        self.accaracy = tf.reduce_mean(tf.cast(self.z, tf.float32))


if __name__ == ‘__main__‘:
    net = MLPNet()  # 启动tensorflow绘图的MLPNet
    net.forward()   # 启动前向计算
    net.backward()  # 启动后向计算
    net.acc()       # 启动精度计算
    init = tf.global_variables_initializer()  # 定义初始化tensorflow所有变量操作
    with tf.Session() as sess:                # 创建一个Session会话
        sess.run(init)                        # 执行init变量内的初始化所有变量的操作
        for i in range(10000):                # 训练10000次
            ax, ay = mnist.train.next_batch(100)  # 从mnist数据集中取数据出来 ax接收图片 ay接收标签
            loss, accaracy, _ = sess.run(fetches=[net.cost, net.accaracy, net.opt], feed_dict=net.x: ax, net.y: ay)  # 将数据喂进神经网络(以字典的方式传入) 接收loss返回值
            if i % 1000 == 0:  # 每训练1000次
                test_ax, test_ay = mnist.test.next_batch(100)  # 则使用测试集对当前网络进行测试
                test_output = sess.run(net.output, feed_dict=net.x: test_ax)  # 将测试数据喂进网络 接收一个output值
                z = tf.equal(tf.argmax(test_output, 1, name=‘output_max‘), tf.argmax(test_ay, 1, name=‘test_y_max‘))  # 对output值和标签y值进行求比较运算
                accaracy2 = sess.run(tf.reduce_mean(tf.cast(z, tf.float32)))  # 求出精度的准确率进行打印
                print(accaracy2)  # 打印当前测试集的精度 

最后附上训练截图:

技术分享图片


mnist数据集手写体识别(cnn实现)(代码片段)

...nsorflowPython基础资料下载链接:深度学习基础网络模型(mnist手写体识别数据集)MNIST数据集手写体识别(CNN实现)importtensorflowastfimporttensorflow.examples.tutorials.mnist.input_dataasinput_data#导入 查看详情

实现手写体mnist数据集的识别任务(代码片段)

    实现手写体mnist数据集的识别任务,共分为三个模块文件,分别是描述网络结构的前向传播过程文件(mnist_forward.py)、描述网络参数优化方法的反向传播过程文件(mnist_backward.py)、验证模型准确率的 测试... 查看详情

基于mnist数据集实现手写数字识别(代码片段)

...课程中,多次用到mnist数据集。mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx*-ubyte.gz的二进制文件。可以直接从官网进行下载http://yann.lecun.com/exd... 查看详情

mnist数据集手写体识别(seq2seq实现)(代码片段)

...nsorflowPython基础资料下载链接:深度学习基础网络模型(mnist手写体识别数据集)MNIST数据集手写体识别(CNN实现)importtensorflowastfimporttensorflow.examples.tutorials.mnist.input_dataasinput_data#导入 查看详情

pytorch使用mnist数据集实现手写数字识别(代码片段)

使用mnist数据集实现手写数字识别是入门必做吧。这里使用pyTorch框架进行简单神经网络的搭建。首先导入需要的包。1importtorch2importtorch.nnasnn3importtorch.utils.dataasData4importtorchvision 接下来需要下载mnist数据集。我们创建train_data。... 查看详情

全连接神经网络实现识别手写数据集mnist(代码片段)

全连接神经网络实现识别手写数据集MNISTMNIST是一个由美国由美国邮政系统开发的手写数字识别数据集。手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载。总共4个文件,该文件是二进制内容。train-images-idx3-uby... 查看详情

keras实现mnist数据集手写数字识别(代码片段)

一.Tensorflow环境的安装这里我们只讲CPU版本,使用Anaconda进行安装a.首先我们要安装Anaconda链接:https://pan.baidu.com/s/1AxdGi93oN9kXCLdyxOMnRA密码:79ig过程如下:第一步:点击next第二步:IAgree第三步:JustME第四步:自己选择一个恰当位置... 查看详情

java实现bp神经网络mnist手写数字识别(代码片段)

Java实现BP神经网络,内含BP神经网络类,采用MNIST数据集,包含服务器和客户端程序,可在服务器训练后使客户端直接使用训练结果,界面有画板,可以手写数字Java实现BP神经网络MNIST手写数字识别如果需要源码,请在下方评论区... 查看详情

用pytorch实现多层感知机(mlp)(全连接神经网络fc)分类mnist手写数字体的识别(代码片段)

1.导入必备的包1importtorch2importnumpyasnp3fromtorchvision.datasetsimportmnist4fromtorchimportnn5fromtorch.autogradimportVariable6importmatplotlib.pyplotasplt7importtorch.nn.functionalasF8fromtorch.utils.da 查看详情

mnist手写数字识别tensorflow(代码片段)

Mnist手写数字识别Tensorflow任务目标了解mnist数据集搭建和测试模型编辑环境操作系统:Win10python版本:3.6集成开发环境:pycharmtensorflow版本:1.*了解mnist数据集mnist数据集:mnist数据集下载地址??MNIST数据集来自美国国家标准与技术研究所,... 查看详情

基于tensorflow2.x实现bp神经网络,实践mnist手写数字识别(代码片段)

一、MNIST数据集MNIST是一个非常有名的手写数字识别数据集,在很多资料中都会被用作深度学习的入门样例。在Tensorflow2.x中该数据集已被封装在了tf.keras.datasets工具包下,如果没有指定数据集的位置,并先前也没有使... 查看详情

基于tensorflow2.x实现bp神经网络,实践mnist手写数字识别(代码片段)

一、MNIST数据集MNIST是一个非常有名的手写数字识别数据集,在很多资料中都会被用作深度学习的入门样例。在Tensorflow2.x中该数据集已被封装在了tf.keras.datasets工具包下,如果没有指定数据集的位置,并先前也没有使... 查看详情

matlab练习程序(神经网络识别mnist手写数据集)(代码片段)

...该有些地方写的还是不对。这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码。mnist数据集训练数据一共有28*28*60000个像素,标签有60000个。测试数据一共有28*28*10000个,标签10000个。这里神经网络输入层... 查看详情

matlab练习程序(神经网络识别mnist手写数据集)(代码片段)

...该有些地方写的还是不对。这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码。mnist数据集训练数据一共有28*28*60000个像素,标签有60000个。测试数据一共有28*28*10000个,标签10000个。这里神经网络输入层... 查看详情

三种方法实现mnist手写数字识别(代码片段)

MNIST数据集下载:importtensorflowastffromtensorflow.examples.tutorials.mnistimportinput_datamnist=input_data.read_data_sets("MNIST_data/",one_hot=True)#one_hot独热编码,也叫一位有效编码。在任意时候只有一位为1,其他位都是01使用逻辑回归:importt 查看详情

图像分类基于pytorch搭建lstm实现mnist手写数字体识别(双向lstm,附完整代码和数据集)(代码片段)

写在前面:首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌。在https://blog.csdn.net/AugustMe/article/details/128969138文章中,我们... 查看详情

使用线性回归识别手写阿拉伯数字mnist数据集(代码片段)

学习了tensorflow的线性回归。首先是一个sklearn中makeregression数据集,对其进行线性回归训练的例子。来自腾讯云实验室importtensorflowastfimportnumpyasnpclasslinearRegressionModel:def__init__(self,x_dimen):self.x_dimen=x_dimenself._index_in_epoch 查看详情

tensorflow简单实例(代码片段)

TF手写体识别简单实例:TensorFlow很适合用来进行大规模的数值计算,其中也包括实现和训练深度神经网络模型。下面将介绍TensorFlow中模型的基本组成部分,同时将构建一个CNN模型来对MNIST数据集中的数字手写体进行识别。基本设... 查看详情