mnist手写数字识别tensorflow(代码片段)

lyhlive lyhlive     2022-12-07     611

关键词:

Mnist手写数字识别 Tensorflow

任务目标

  • 了解mnist数据集
  • 搭建和测试模型

编辑环境

操作系统:Win10
python版本:3.6
集成开发环境:pycharm
tensorflow版本:1.*


了解mnist数据集

mnist数据集:mnist数据集下载地址
??MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.
??图片是以字节的形式进行存储, 我们需要把它们读取到 NumPy array 中, 以便训练和测试算法。
读取mnist数据集

mnist = input_data.read_data_sets("mnist_data", one_hot=True)

搭建和测试模型

代码

def getMnistModel(savemodel,is_train):
    """
    :param savemodel:  模型保存路径
    :param is_train: true为训练,false为测试模型
    :return:None
    """
    mnist = input_data.read_data_sets("mnist_data", one_hot=True)
    with tf.variable_scope("data"):
        x = tf.placeholder(tf.float32,shape=[None,784]) # 784=28*28*1 宽长为28,单通道图片
        y_true = tf.placeholder(tf.int32,shape=[None,10]) # 10个类别

    with tf.variable_scope("conv1"):
        w_conv1 = tf.Variable(tf.random_normal([10,10,1,32])) # 10*10的卷积核  1个通道的输入图像  32个不同的卷积核,得到32个特征图
        b_conv1 = tf.Variable(tf.constant(0.0,shape=[32]))
        x_reshape = tf.reshape(x,[-1,28,28,1]) # n张 28*28 的单通道图片
        conv1 = tf.nn.relu(tf.nn.conv2d(x_reshape,w_conv1,strides=[1,1,1,1],padding="SAME")+b_conv1) #[1, 1, 1, 1] 中间2个1,卷积每次滑动的步长 padding=‘SAME‘ 边缘自动补充
        pool1 = tf.nn.max_pool(conv1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME") # 池化窗口为[1,2,2,1] 中间2个2,池化窗口每次滑动的步长  padding="SAME" 考虑边界,如果不够用 用0填充
    with tf.variable_scope("conv2"):
        w_conv2 = tf.Variable(tf.random_normal([10,10,32,64]))
        b_conv2 = tf.Variable(tf.constant(0.0,shape=[64]))
        conv2 = tf.nn.relu(tf.nn.conv2d(pool1,w_conv2,strides=[1,1,1,1],padding="SAME")+b_conv2)
        pool2 = tf.nn.max_pool(conv2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
    with tf.variable_scope("fc"):
        w_fc = tf.Variable(tf.random_normal([7*7*64,10])) # 经过两次卷积和池化 28 * 28/(2+2) = 7 * 7
        b_fc = tf.Variable(tf.constant(0.0,shape=[10]))
        xfc_reshape = tf.reshape(pool2,[-1,7*7*64])
        y_predict = tf.matmul(xfc_reshape,w_fc)+b_fc
    with tf.variable_scope("loss"):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))
    with tf.variable_scope("optimizer"):
        train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
    with tf.variable_scope("acc"):
        equal_list = tf.equal(tf.arg_max(y_true,1),tf.arg_max(y_predict,1))
        accuracy = tf.reduce_mean(tf.cast(equal_list,tf.float32))

    # tensorboard
    # tf.summary.histogram用来显示直方图信息
    # tf.summary.scalar用来显示标量信息
    # Summary:所有需要在TensorBoard上展示的统计结果
    tf.summary.histogram("weight",w_fc)
    tf.summary.histogram("bias",b_fc)
    tf.summary.scalar("loss",loss)
    tf.summary.scalar("acc",accuracy)
    merged = tf.summary.merge_all()
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        filewriter = tf.summary.FileWriter("tfboard",graph=sess.graph)
        if is_train:
            for i in range(500):
                x_train,y_train = mnist.train.next_batch(100)
                sess.run(train_op,feed_dict=x:x_train,y_true:y_train)
                summary = sess.run(merged,feed_dict=x:x_train,y_true:y_train)
                filewriter.add_summary(summary,i)
                print("第%d训练,准确率为%f"%(i+1,sess.run(accuracy,feed_dict=x:x_train,y_true:y_train)))
            saver.save(sess,savemodel)
        else:
            count = 0.0
            epochs = 100
            saver.restore(sess, savemodel)
            for i in range(epochs):
                x_test, y_test = mnist.train.next_batch(1)
                print("第%d张图片,真实值为:%d预测值为:%d" % (i + 1,
                                                 tf.argmax(sess.run(y_true, feed_dict=x: x_test, y_true: y_test),
                                                           1).eval(),
                                                 tf.argmax(
                                                     sess.run(y_predict, feed_dict=x: x_test, y_true: y_test),
                                                     1).eval()
                                                 ))
                if (tf.argmax(sess.run(y_true, feed_dict=x: x_test, y_true: y_test), 1).eval() == tf.argmax(
                        sess.run(y_predict, feed_dict=x: x_test, y_true: y_test), 1).eval()):
                    count = count + 1
            print("正确率为 %.2f " % float(count * 100 / epochs) + "%")

查看Tensorboard

?? TensorBoard是TensorFlow下的一个可视化的工具,能够帮助我们在训练大规模神经网络过程中出现的复杂且不好理解的运算。TensorBoard能展示你训练过程中绘制的图像、网络结构等。

  • 复制你训练后log文件保存的位置
    技术图片
  • 打开终端
    输入命令 tensorboard --logdir=D:GongChengmnist fboard
    tensorboard --logdir= + 文件夹的路径
    技术图片
  • 用浏览器打开上述地址
    技术图片

完整代码

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

def getMnistModel(savemodel,is_train):
    """
    :param savemodel:  模型保存路径
    :param is_train: True为训练,False为测试模型
    :return:None
    """
    mnist = input_data.read_data_sets("mnist_data", one_hot=True)
    with tf.variable_scope("data"):
        x = tf.placeholder(tf.float32,shape=[None,784]) # 784=28*28*1 宽长为28,单通道图片
        y_true = tf.placeholder(tf.int32,shape=[None,10]) # 10个类别

    with tf.variable_scope("conv1"):
        w_conv1 = tf.Variable(tf.random_normal([10,10,1,32])) # 10*10的卷积核  1个通道的输入图像  32个不同的卷积核,得到32个特征图
        b_conv1 = tf.Variable(tf.constant(0.0,shape=[32]))
        x_reshape = tf.reshape(x,[-1,28,28,1]) # n张 28*28 的单通道图片
        conv1 = tf.nn.relu(tf.nn.conv2d(x_reshape,w_conv1,strides=[1,1,1,1],padding="SAME")+b_conv1) #[1, 1, 1, 1] 中间2个1,卷积每次滑动的步长 padding=‘SAME‘ 边缘自动补充
        pool1 = tf.nn.max_pool(conv1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME") # 池化窗口为[1,2,2,1] 中间2个2,池化窗口每次滑动的步长  padding="SAME" 考虑边界,如果不够用 用0填充
    with tf.variable_scope("conv2"):
        w_conv2 = tf.Variable(tf.random_normal([10,10,32,64]))
        b_conv2 = tf.Variable(tf.constant(0.0,shape=[64]))
        conv2 = tf.nn.relu(tf.nn.conv2d(pool1,w_conv2,strides=[1,1,1,1],padding="SAME")+b_conv2)
        pool2 = tf.nn.max_pool(conv2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
    with tf.variable_scope("fc"):
        w_fc = tf.Variable(tf.random_normal([7*7*64,10])) # 经过两次卷积和池化 28 * 28/(2+2) = 7 * 7
        b_fc = tf.Variable(tf.constant(0.0,shape=[10]))
        xfc_reshape = tf.reshape(pool2,[-1,7*7*64])
        y_predict = tf.matmul(xfc_reshape,w_fc)+b_fc
    with tf.variable_scope("loss"):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))
    with tf.variable_scope("optimizer"):
        train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
    with tf.variable_scope("acc"):
        equal_list = tf.equal(tf.arg_max(y_true,1),tf.arg_max(y_predict,1))
        accuracy = tf.reduce_mean(tf.cast(equal_list,tf.float32))

    # tensorboard
    # tf.summary.histogram用来显示直方图信息
    # tf.summary.scalar用来显示标量信息
    # Summary:所有需要在TensorBoard上展示的统计结果
    tf.summary.histogram("weight",w_fc)
    tf.summary.histogram("bias",b_fc)
    tf.summary.scalar("loss",loss)
    tf.summary.scalar("acc",accuracy)
    merged = tf.summary.merge_all()
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        filewriter = tf.summary.FileWriter("tfboard",graph=sess.graph)
        if is_train:
            for i in range(500):
                x_train,y_train = mnist.train.next_batch(100)
                sess.run(train_op,feed_dict=x:x_train,y_true:y_train)
                summary = sess.run(merged,feed_dict=x:x_train,y_true:y_train)
                filewriter.add_summary(summary,i)
                print("第%d训练,准确率为%f"%(i+1,sess.run(accuracy,feed_dict=x:x_train,y_true:y_train)))
            saver.save(sess,savemodel)
        else:
            count = 0.0
            epochs = 100
            saver.restore(sess, savemodel)
            for i in range(epochs):
                x_test, y_test = mnist.train.next_batch(1)
                print("第%d张图片,真实值为:%d预测值为:%d" % (i + 1,
                                                 tf.argmax(sess.run(y_true, feed_dict=x: x_test, y_true: y_test),
                                                           1).eval(),
                                                 tf.argmax(
                                                     sess.run(y_predict, feed_dict=x: x_test, y_true: y_test),
                                                     1).eval()
                                                 ))
                if (tf.argmax(sess.run(y_true, feed_dict=x: x_test, y_true: y_test), 1).eval() == tf.argmax(
                        sess.run(y_predict, feed_dict=x: x_test, y_true: y_test), 1).eval()):
                    count = count + 1
            print("正确率为 %.2f " % float(count * 100 / epochs) + "%")
if __name__ == ‘__main__‘:
    modelPath = "model/mnist_model"
    getMnistModel(modelPath,True)












第三节,tensorflow使用cnn实现手写数字识别(代码片段)

...括以下几块内容[1]:导入数据,即测试集和验证集[2]:引入tensorflow启动InteractiveSession(比session更灵活)[3]:定义两个初始化w和b的函数,方便后续操作[4]:定义卷积和池化函数,这里卷积采用pad 查看详情

tensorflow学习笔记mnist手写数字识别(代码片段)

...模型,以达到稳定的高性能,而是学会如何使用Tensorflow解 查看详情

tensorflow与flask结合打造手写体数字识别(代码片段)

TensorFlow与Flask结合打造手写体数字识别主要步骤:获取mnist数据集分别创建regression和convolution的模型,设置对应的计算方式、参数等信息创建regression、convolution获取数据,调用对应模型进行训练、测试最后保存对应模型创建mnist... 查看详情

tensorflow之mnist手写数字识别:分类问题(代码片段)

整体代码:#数据读取importtensorflowastfimportmatplotlib.pyplotaspltimportnumpyasnpfromtensorflow.examples.tutorials.mnistimportinput_datamnist=input_data.read_data_sets("MNIST_data/",one_hot=True)#定义待输入数据的占位符# 查看详情

tensorflow-神经网络识别手写数字(代码片段)

数据下载连接:http://yann.lecun.com/exdb/mnist/下载t10k-images-idx3-ubyte.gz;t10k-labels-idx1-ubyte.gz;train-images-idx3-ubyte.gz;train-labels-idx1-ubyte.gz简单神经网络识别手写数字importtensorf 查看详情

深度学习手写数字识别tensorflow2实验报告(代码片段)

实验一:手写数字识别一、实验目的利用深度学习实现手写数字识别,当输入一张手写图片后,能够准确的识别出该图片中数字是几。输出内容是0、1、2、3、4、5、6、7、8、9的其中一个。二、实验原理(1)采... 查看详情

tensorflow从入门到精通——手写数字识别(代码片段)

importtensorflowastfprint(tf.__version__)2.6.0一、数据集mnist=tf.keras.datasets.mnist(train_images,train_labels),(test_images,test_labels)=mnist.load_data()train_images.shape,train_labels.shape((60 查看详情

tensorflow项目实战一:mnist手写数字识别(代码片段)

...测结果,然后计算损失,进行训练。  代码如下:importtensorflowastffromtensorflow.examples.tutorials.mnistimport 查看详情

07训练tensorflow识别手写数字

   打开PythonShell,输入以下代码:1importtensorflowastf2fromtensorflow.examples.tutorials.mnistimportinput_data34#获取数据(如果存在就读取,不存在就下载完再读取)5mnist=input_data.read_data_sets("MNIST_data/",one_hot=Tr 查看详情

tensorflow暑期实践——基于单个神经元的手写数字识别(全部代码)(代码片段)

#coding:utf-8importtensorflowastfimportosos.environ["CUDA_VISIBLE_DEVICES"]="-1"print(tf.__version__)print(tf.test.is_gpu_available())fromtensorflow.examples.tutorials.mnistimportinput_datamnist=input 查看详情

tensorflow入门实战|第1周:实现mnist手写数字识别(代码片段)

...语言环境:Python3.6.5编译器:jupyternotebook深度学习环境:TensorFlow2文章目录一、前期工作1.设置GPU(如果使用的是CPU可以忽略这步)2.导入数据3.归一化4.可视化图片5.调整图片格式二、构建CNN网络模型三、编译模型四、训练模型五... 查看详情

tensorflow快速入门2--实现手写数字识别

Tensorflow快速入门2–实现手写数字识别环境:虚拟机ubuntun16.0.4Tensorflow(仅使用cpu版)Tensorflow安装见:http://blog.csdn.net/yhhyhhyhhyhh/article/details/54429034或者:http://www.tensorfly.cn/tfdoc/get_started/os_setup.html本文将利用Ten 查看详情

tensorflow入门实战|第1周:实现mnist手写数字识别(代码片段)

...语言环境:Python3.6.5编译器:jupyternotebook深度学习环境:TensorFlow2文章目录一、前期工作1.设置GPU(如果使用的是CPU可以忽略这步)2.导入数据3.归一化4.可视化图片5.调整图片格式二、构建CNN网络模型三、编译模型四、训练模型五... 查看详情

tensorflow1.x代码实战系列:mnist手写数字识别(代码片段)

相关资料https://github.com/aymericdamien/TensorFlow-Examples实战1:MNSIT手写数字识别这里使用的是普通的DNN模型实现MNIST识别,如果想提高精度,可以自行使用CNN模型。'''AlogisticregressionlearningalgorithmexampleusingTensorFlowlibra... 查看详情

tensorflow实现cnn简单手写数字识别(python)(代码片段)

...行数:112行(主程序)开发环境:Python3.9、OpenCV4.5、Tensorflow2.7该源码均通过亲自测试可正常运行下载地址:点击下载简要概述:主要使用到的库:Numpy,Pygame,Tensorflow训练模型用到的是minist数据集由于时... 查看详情

tensorflow之手写数字识别mnist

官方文档: MNISTForMLBeginners- https://www.tensorflow.org/get_started/mnist/beginners DeepMNISTforExperts- https://www.tensorflow.org/get_started/mnist/pros 版本: TensorFlow1.2 查看详情

tensorflow实践mnist手写数字识别

minst数据集                        tensorflow的文档中就自带了mnist手写数字识别的例子,是一个很经典也比较简单的入门tensorflow的例子,非常值得自己动 查看详情

第二节,tensorflow使用前馈神经网络实现手写数字识别(代码片段)

一感知器    感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695    感知器(Perceptron)是二分类的线性分类模型,其输入为实例的特征向量,输出为实例的类别,取+1和-1。这种算法的局限... 查看详情