tensorflow学习笔记:mnist机器学习入门

Fzu_LJ Fzu_LJ     2022-08-14     658

关键词:

     学习深度学习,首先从深度学习的入门MNIST入手。通过这个例子,了解Tensorflow的工作流程和机器学习的基本概念。

一  MNIST数据集

     MNIST是入门级的计算机视觉数据集,包含了各种手写数字的图片。在这个例子中就是通过机器学习训练一个模型,以识别图片中的数字。

     MNIST数据集来自 http://yann.lecun.com/exdb/mnist/

     Tensorflow提供了一份python代码用于自动下载安装数据集。Tensorflow官方文档中的url打不开,在CSDN上找到了一个分享:http://download.csdn.net/detail/u010417185/9588647

     和官方有点不同的是,我直接把四个数据集下载下来,放在/tmp/mnist下,在项目文件中使用以下代码导入:

import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)

     这里的数据集分为两个部分:60000的训练数据集(mnist.train)和10000的测试数据集(mnist.test),测试集的作用是帮助模型泛化。数据对应包含图片和标签,分别用mnist.train.images,mnist.train.lables,mnist.test.images,mnist.test.lables来表示。每张图片有28×28=784个像素点,因此训练图片mnist.train.images的张量表示为 [60000, 784],第一个纬度用于索引图片,第二纬度用于索引像素点。由于判断10个数字,这里采用热独,即one-hot-vectors,除了一位数字为1外其他纬度数字为0。例如判断数字为0则其表示为[1,0,0,0,0,0,0,0,0,0]。因此训练标签表示为[10000,10],第一纬度索引图片,第二纬度判断数字。

二  softmax回归介绍

   softmax模型可以给不同的对象分配概率。根据下图,对输入的x的加权求和,再分别加上一个偏置量,最后输入到softmax函数中:

 

      具体转换为公式,即:

三  实现回归模型

      首先进行模型的定义,如下:

x = tf.placeholder(tf.float32, [None, 784]) #使用占位符placeholder,第一维度可指定图片的数量是任意的
W = tf.Variable(tf.zeros([784,10]))  #初始化权值
b = tf.Variable(tf.zeros([10]))      #初始化偏置值
y = tf.nn.softmax(tf.matmul(x,W) + b)  #根据公式计算

四  训练模型

     选用的损失函数为交叉熵,其定义如下:

    其中y为预测的概率分布,y'为实际分布。

    代码如下:

y_ = tf.placeholder("float", [None,10])  #表示实际的分布
cross_entropy = -tf.reduce_sum(y_*tf.log(y))  #计算损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  #以梯度下降算法最小化损失函数
init = tf.initialize_all_variables()  #初始化所有变量
sess = tf.Session()  #定义会话
sess.run(init)   #初始化会话

for i in range(1000):   #开始训练,循环训练1000次
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

 

五  评估模型

    选用tf.argmax函数评估,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签,比如tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1) 代表正确的标签,用 tf.equal 来检测预测是否与真实标签匹配(索引位置一样表示匹配)。

    代码如下:

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))  #评估
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))  #将结果转换为浮点数
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})  #输出

六  代码

import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)

x = tf.placeholder(tf.float32, [None, 784]) #使用占位符placeholder,第一维度可指定图片的数量是任意的
W = tf.Variable(tf.zeros([784,10]))  #初始化权值
b = tf.Variable(tf.zeros([10]))      #初始化偏置值
y = tf.nn.softmax(tf.matmul(x,W) + b)  #根据公式计算
y_ = tf.placeholder("float", [None,10])  #表示实际的分布
cross_entropy = -tf.reduce_sum(y_*tf.log(y))  #计算损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  #以梯度下降算法最小化损失函数
init = tf.initialize_all_variables()  #初始化所有变量
sess = tf.Session()  #定义会话
sess.run(init)   #初始化会话

for i in range(1000):   #开始训练,循环训练1000次
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))  #评估
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))  #将结果转换为浮点数
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})  #输出
View Code

七  实验结果

    最终测试结果精确度在91%左右。

机器学习框架tensorflow数字识别mnist

SoftMax回归 http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92我们的训练集由  个已标记的样本构成: ,其中输入特征。(我们对符号的约定如下:特征向量  的维度为 ,其中  对应截距项。)... 查看详情

mnist机器学习入门

  在前一个博客中,我们已经对MNIST数据集和TensorFlow中MNIST数据集的载入有了基本的了解。本节将真正以TensorFlow为工具,写一个手写体数字识别程序,使用的机器学习方法是Softmax回归。一、Softmax回归的原理  Softmax回归是一... 查看详情

tensorflow之mnist解析

...年什么技术最火爆,无疑是google领衔的深度学习开源框架Tensorflow。本文简述一下深度学习的入门例子MNIST。深度学习简单介绍首先要简单区别几个概念:人工智能,机器学习,深度学习,神经网络。这几个词应该是出现的最为频... 查看详情

tensorflow学习笔记三:实例数据下载与读取

...ist手写数字分类识别,因此我们应该先下载这个数据集。tensorflow提供一个input_data.py文件,专门用于下载mnist数据,我们直接调用就可以了,代码如下:importtensorflow.examples.tutorials.mnist.input_datamnist=input_data.read_data_sets("MNIST_data/" 查看详情

tensorflow学习笔记

学习mnist_softmax.py向量:有大小和方向的量。标量:只有大小的量。向量的写法:起点a,重点b,则写为ab上面加一个剪头。如果是在一个直角坐标系,则可以用数对的形式表示。例如(2,3)每一手写数字样本图片都是28*28像素的... 查看详情

tensorflow学习笔记(对mnist经典例程的)的代码注释与理解

1#coding:utf-82#日期2017年9月4日环境Python3.5 TensorFlow1.3win10开发环境。3importtensorflowastf4fromtensorflow.examples.tutorials.mnistimportinput_data5importos678#基础的学习率9LEARNING_RATE_BASE=0.81011#学习率的衰减率12 查看详情

tensorflow学习笔记---2--dcgan代码学习

...rks)的网络结构。代码下载地址https://github.com/carpedm20/DCGAN-tensorflow注1:发现代码中以mnist为训练集的网络和以无标签数据集(以下简称unlabeled_dataset)为训练集的网络不同,结构有别。以下笔记主要针对 查看详情

tensorflow学习笔记五:mnist实例--卷积神经网络(cnn)

...我就分成几个部分来叙述。首先,下载并加载数据:importtensorflowastfimporttensorflow.examples.tutorials.mnist.input_dataasinput_datamnist=input_data.read_data 查看详情

tensflow框架学习之mnist机器学习入门

前言:初学TensorFlow和机器学习,MNIST算法的每条语句都不是很清楚,通过查阅资料,将每句代码的基本用法差不多理解了。希望能够帮助正在学习的你  [python]viewplaincopy <span style="font-size:32px;">from tensorflow.e... 查看详情

tensorflow学习笔记

importtensorflowastfimportnumpyasnpimportmathimporttensorflow.examples.tutorials.mnistasmnsess=tf.InteractiveSession()mnist=mn.input_data.read_data_sets("E:\Python35\Lib\site-packages\tensorflow\ 查看详情

机器学习-tensorflowsharp简单使用与knn识别mnist流程(代码片段)

机器学习是时下非常流行的话题,而Tensorflow是机器学习中最有名的工具包。TensorflowSharp是Tensorflow的C#语言表述。本文会对TensorflowSharp的使用进行一个简单的介绍。本文会先介绍Tensorflow的一些基本概念,然后实现一些基本操作例... 查看详情

tensorflow实战计算机视觉之mnist数据集

计算机视觉方向使用深度学习主要是卷积神经网络,可以参考这篇文章:零基础入门深度学习(4)-卷积神经网络MNIST机器学习入门:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html卷积神经网络改善MNIST数据集识别准确率M... 查看详情

学习笔记tf024:tensorflow实现softmaxregression(回归)识别手写数字

TensorFlow实现SoftmaxRegression(回归)识别手写数字。MNIST(MixedNationalInstituteofStandardsandTechnologydatabase),简单机器视觉数据集,28X28像素手写数字,只有灰度值信息,空白部分为0,笔迹根据颜色深浅取[0,1],784维,丢弃二维空间信息,目标... 查看详情

mnist机器学习入门(代码片段)

...下载MNIST数据集,并打印一些基本信息:pythondownload.py#从tensorflow.examples.tutorials.mnist引入模块,这是TensorFlow为了教学MNIST而提前编制的程序fromtensorflow.examples.tutorials.mnistimportinput_data#从MNIST_data/中读取MNIST数据。这条语句在数据不... 查看详情

机器学习进阶笔记之一|tensorflow安装与入门

原文链接:https://zhuanlan.zhihu.com/p/22410917TensorFlow是Google基于DistBelief进行研发的第二代人工智能学习系统,被广泛用于语音识别或图像识别等多项机器深度学习领域。其命名来源于本身的运行原理。Tensor(张量)意味着N维数组,Fl... 查看详情

tensorflow学习笔记入门(代码片段)

一、TensorFlow是什么?是谷歌开源的机器学习实现框架,本文从Python语言来理解学习Tensorflow以及机器学习的知识。TensorFlow的API主要分两个层次,核心层和基于核心层的高级API。核心层面向机器学习的研究人员,以... 查看详情

tensorflow笔记之mnist手写识别系列一

tensorflow笔记(四)之MNIST手写识别系列一版权声明:本文为博主原创文章,转载请指明转载地址http://www.cnblogs.com/fydeblog/p/7436310.html前言这篇博客将利用神经网络去训练MNIST数据集,通过学习到的模型去分类手写数字。我会将本篇... 查看详情

深度学习与tensorflow2.0卷积神经网络(cnn)(代码片段)

注:在很长一段时间,MNIST数据集都是机器学习界很多分类算法的benchmark。初学深度学习,在这个数据集上训练一个有效的卷积神经网络就相当于学习编程的时候打印出一行“HelloWorld!”。下面基于与MNIST数据集非常类似的... 查看详情