tensorflow学习笔记mnist手写数字识别(代码片段)

準提童子 準提童子     2022-12-13     808

关键词:

MNIST是机器学习中的Hello world,前期准备要了解Softmax (multinomial logistic ) regression

MNIST的是一个简单的计算机视觉数据集,它包含一系列手写数字图片,我们将训练一个模型识别图片中的数字

我们本次目的不是训练一个精准模型,以达到稳定的高性能,而是学会如何使用Tensorflow解决简单的问题

本次使用一个简单的模型Softmax Regression,代码是简单的,所有有趣的事情都发生在几行代码里,明白背后中TensorFlow是如何工作的,以及机器学习的核心概念

本文完成以下任务:

1.了解MNIST数据以及Softmax Regression

2.构建一个模型,根据图片的像素识别数字

3.基于上千个样例,使用TensorFlow训练模型识别数字

4.使用测试数据检验正确性

MINIST 数据集

MNIST的数据在机器学习大牛Yann leCun的个人网站上有http://yann.lecun.com/exdb/mnist/

可以使用代码下载数据集:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
数据集被分为三个部分:55000个 数据点data points作为训练mnist.train,10000作为测试mnist.test,5000作为验证mnist.validation

每个MNIST数据点有两个部分:手写的数字图片和其相应的标签,将图片作为x, 标签作为y

每个图片为28*28像素,我们可以将其抽象为数字:


扁平化数据为一个28*28=784大小的向量,如何扁平化是无关紧要的,只要能保证所有图片的扁平化方式一致

按照这种方式,图片被抽象为784维度的向量空间

扁平化数据的过程,丢失图片的2D属性,这是不是不太好?答案是肯定的,后边我们会涉及好的方法,但是在此我们使用简单的方式

处理后的结果是,mnist.train.imgages是一个张量tensor with shape [55000,784], 对于图片中的特定像素,tensor中每个条目中的值为0或1

每个图片有一个相应的标签,为0-9,每个标签抽象为“one-hot vectors” (指只有一个维度为1其他维度都为0的向量),

这样mnist.train.labels就是一个tensor with shape[55000,10]

Softmax Regression

主要讲解模型Softmax 回归

定义模型

在python数值计算中,经常会使用到一些别的库,像numpy,在python外进行昂贵的操作,如矩阵计算。但是Python内外的数据交互会使得高的开销,为避免这些,Tensorflow使用先描述计算图,后再一起计算的方式。

声明placeholder参数化的从外部输入数据,None代表维度可以是任意值

x=tf.placeholder(tf.float32,[None,784])
定义 权重weight偏差biase

W=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
将W和b初始化为全0张量,这是无关紧要的,因为机器学习过程中会自动修改他们

W定义为一个tensor with [784,10],是因为我们要将784维的向量转换为10维的。

实现模型:

y=tf.nn.softmax(tf.matmul(x,W)+b)
使用 tf.matmul(x,W)实现x和W的相乘。仅此一行代码就定义了模型!

训练

为了训练模型,我们需要明确什么样的情况意味着模型好, 交叉熵cross-entropy决定模型损失(Softmax模型通过最小化交叉熵来训练)

存放真实值

y_=tf.placeholder(tf.float32,[None,10])
交叉熵计算:

cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y), reduction_indices=[1]))
tf.log计算y中元素的对数,tf.reduce_sum计算y中第2维元素的相加(y为tensor with shape[None, 10]),因为参数reduction_indices=[1]
最后tf.reduce_mean计算平均值,在源代码中我们不使用该方程,因为它数字上不是稳定的

对于非规范化的逻辑,使用tf.nn.softmax_cross_entropy_with_logits

tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
使用选择的优化算法修改变量,减少损失

train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
我们使用 学习率learning rate为0.5的 梯度下降算法gradient descent algorithm
Tensorflow提供了很多别的优化算法,使用其中一个只需调整一行代码,非常简单

现在可以运行模型,通过InteractiveSession

sess=tf.InteractiveSession()
初始化变量

tf.global_variables_initializer().run()
训练1000次

for _ in range(1000):
    batch_xs,batch_ys=mnist.train.next_batch(100)
    sess.run(train_step,feed_dict=x:batch_xs,y_:batch_ys)
循环中的每一步,随机取100个数据点为一个批次,作为训练。

使用小批次的随机数据称为随机训练stochastic training,理想化的来讲,应该使用所有的为每一步训练,但是这样是昂贵的,

所以,每次使用一个不同子集,这样运算成本低廉,而且同样有效

评估模型

tf.argmax是一个非常有用的方法,它给出张量tensor中沿着某个轴最高条目的索引(未明白),例如tf.argmax(y,1)是模型预测的输出label,tf.argmax(y_,1)是真实的label,

使用tf.equal检验预测与真实是否相匹配

correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))

tf.equal返回一列boolean, 我们转换为0,1,计算平均值:

accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
最后运行计算图获取精确值:

print(sess.run(accuracy,feed_dict=x:mnist.test.images,y_:mnist.test.labels))
准确度为92%

这个结果是否足够好呢?不是的,事实上是非常low的,这是因为我们使用了一个简单的模型,稍微改变可以达到97%,最好的模型可以达到99.7%

重要的是我们在这个模型中学到的东西。

完整代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None


def main(_):
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

  # Create the model
  x = tf.placeholder(tf.float32, [None, 784])
  W = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, W) + b

  # Define loss and optimizer
  y_ = tf.placeholder(tf.float32, [None, 10])

  # The raw formulation of cross-entropy,
  #
  #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
  #                                 reduction_indices=[1]))
  #
  # can be numerically unstable.
  #
  # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
  # outputs of 'y', and then average across the batch.
  cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  sess = tf.InteractiveSession()
  tf.global_variables_initializer().run()
  # Train
  for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict=x: batch_xs, y_: batch_ys)

  # Test trained model
  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print(sess.run(accuracy, feed_dict=x: mnist.test.images,
                                      y_: mnist.test.labels))

if __name__ == '__main__':
  # argparse模块,python用于解析命令行参数
  parser = argparse.ArgumentParser()
  # 指定数据集加载目录为MNIST_data
  parser.add_argument('--data_dir', type=str, default='MNIST_data',
                      help='Directory for storing input data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)


tensorflow基础学习五:mnist手写数字识别

MNIST数据集介绍:fromtensorflow.examples.tutorials.mnistimportinput_data#载入MNIST数据集,如果指定地址下没有已经下载好的数据,tensorflow会自动下载数据mnist=input_data.read_data_sets(‘.‘,one_hot=True)#打印Trainingdatasize:55000。print("Train 查看详情

学习笔记tf024:tensorflow实现softmaxregression(回归)识别手写数字

TensorFlow实现SoftmaxRegression(回归)识别手写数字。MNIST(MixedNationalInstituteofStandardsandTechnologydatabase),简单机器视觉数据集,28X28像素手写数字,只有灰度值信息,空白部分为0,笔迹根据颜色深浅取[0,1],784维,丢弃二维空间信息,目标... 查看详情

tensorflow学习笔记:mnist机器学习入门

...,首先从深度学习的入门MNIST入手。通过这个例子,了解Tensorflow的工作流程和机器学习的基本概念。一 MNIST数据集    MNIST是入门级的计算机视觉数据集,包含了各种手写数字的图片。在这个例子中就是通过机... 查看详情

tensorflow2.0入门教程实战案例

中文文档TensorFlow2/2.0中文文档知乎专栏欢迎关注知乎专栏 https://zhuanlan.zhihu.com/geektutu一、实战教程之强化学习TensorFlow2.0(九)-强化学习70行代码实战PolicyGradientTensorFlow2.0(八)-强化学习DQN玩转gymMountainCarTensorFlow2.0(七)-强化学习Q-Le... 查看详情

深度学习笔记:基于keras库的mnist手写数字识别(代码片段)

...f(1)作者正式Keras的作者)的,后来被Google收购并集成到Tensorflow中了。2数据加载和确认¶        MNIST数据集为Keras内置数据集,当然所谓的内置数据集并不是说数据文件本身内置于Tensorflow/Keras库,而是说其中封装了m... 查看详情

tensorflow之手写数字识别mnist

官方文档: MNISTForMLBeginners- https://www.tensorflow.org/get_started/mnist/beginners DeepMNISTforExperts- https://www.tensorflow.org/get_started/mnist/pros 版本: TensorFlow1.2 查看详情

tensorflow实践mnist手写数字识别

minst数据集                        tensorflow的文档中就自带了mnist手写数字识别的例子,是一个很经典也比较简单的入门tensorflow的例子,非常值得自己动 查看详情

android+tensorflow+cnn+mnist手写数字识别实现

SyncHere  查看详情

android+tensorflow+cnn+mnist手写数字识别实现

SyncHere  查看详情

tensorflow入门实战|第1周:实现mnist手写数字识别(代码片段)

...语言环境:Python3.6.5编译器:jupyternotebook深度学习环境:TensorFlow2文章目录一、前期工作1.设置GPU(如果使用的是CPU可以忽略这步)2.导入数据3.归一化4.可视化图片5.调整图片格式二、构建CNN网络模型三、编译模型四、训练模型五... 查看详情

一文全解:利用谷歌深度学习框架tensorflow识别手写数字图片(初学者篇)

笔记整理者:王小草 笔记整理时间2017年2月24日原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&locationNum=5Tensorflow官方英文文档地址:https://www.tensorflow.org/get_started/mnist/beginners&n 查看详情

tensorflow入门实战|第1周:实现mnist手写数字识别(代码片段)

...语言环境:Python3.6.5编译器:jupyternotebook深度学习环境:TensorFlow2文章目录一、前期工作1.设置GPU(如果使用的是CPU可以忽略这步)2.导入数据3.归一化4.可视化图片5.调整图片格式二、构建CNN网络模型三、编译模型四、训练模型五... 查看详情

mnist手写数字识别tensorflow(代码片段)

Mnist手写数字识别Tensorflow任务目标了解mnist数据集搭建和测试模型编辑环境操作系统:Win10python版本:3.6集成开发环境:pycharmtensorflow版本:1.*了解mnist数据集mnist数据集:mnist数据集下载地址??MNIST数据集来自美国国家标准与技术研究所,... 查看详情

tensorflow学习笔记三:实例数据下载与读取

...ist手写数字分类识别,因此我们应该先下载这个数据集。tensorflow提供一个input_data.py文件,专门用于下载mnist数据,我们直接调用就可以了,代码如下:importtensorflow.examples.tutorials.mnist.input_datamnist=input_data.read_data_sets("MNIST_data/" 查看详情

第一周:mnist手写数字识别|365天深度学习训练营

...mnist手写数字识别难度:新手入门语言:Python3、TensorFlow2时间:7月25-7月29日🍺要求:跑通程序了解深度学习是什么参考文章:🔗深度学习100例-卷积神经网络(CNN)实现mnist手写数字识别|第1天 查看详情

第一周:mnist手写数字识别|365天深度学习训练营

...mnist手写数字识别难度:新手入门语言:Python3、TensorFlow2时间:7月25-7月29日🍺要求:跑通程序了解深度学习是什么参考文章:🔗深度学习100例-卷积神经网络(CNN)实现mnist手写数字识别|第1天 查看详情

使用tensorflow和mnist识别自己手写的数字

#!/usr/bin/envpython3fromtensorflow.examples.tutorials.mnistimportinput_datamnist=input_data.read_data_sets(‘MNIST_data‘,one_hot=True)importtensorflowastfsess=tf.InteractiveSession()x=tf.placeholder(t 查看详情

tensorflow之mnist手写数字识别:分类问题(代码片段)

整体代码:#数据读取importtensorflowastfimportmatplotlib.pyplotaspltimportnumpyasnpfromtensorflow.examples.tutorials.mnistimportinput_datamnist=input_data.read_data_sets("MNIST_data/",one_hot=True)#定义待输入数据的占位符# 查看详情