tensorflow之mnist手写数字识别:分类问题(代码片段)

lsm-boke lsm-boke     2023-01-18     189

关键词:

整体代码:

#数据读取
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

#定义待输入数据的占位符
#mnist中每张照片共有28*28=784个像素点
x = tf.placeholder(tf.float32,[None,784],name="X")

#0-9一共10个数字=>10个类别
y = tf.placeholder(tf.float32,[None,10],name="Y")

#定义模型变量
#以正态分布的随机数初始化权重W,以常数0初始化偏置b
#在神经网络中,权值W的初始值通常设为正态分布的随机数,偏置项b的初始值通常也设置为正态分布的随机数或常数。
W = tf.Variable(tf.random_normal([784,10],name="W"))
b = tf.Variable(tf.zeros([10]),name="b")

#用单个神经元构建神经网络
forward=tf.matmul(x,W) + b   #前向计算

#结果分类
#当我们处理多分类任务的时候,通常需要使用Softmax Regression模型。Softmax会对每一类别估算出一个概率。
#工作原理:将判定为某一类的特征相加,然后将这些特征转化为判定是这一类的概率
pred = tf.nn.softmax(forward)     #Softmax分类

#设置训练参数
train_epochs = 120     #训练轮数
batch_size = 120      #单次训练样本数(批次大小)
total_batch = int(mnist.train.num_examples/batch_size)              #一轮训练有多少批次
display_step = 1   #显示粒度
learning_rate = 0.01             #学习率 

#概率估算值需要将预测输出值控制在[0,1]区间内。二元分类问题的目标是正确预测两个可能标签中的一个
#逻辑回归可以用于处理这类问题。二元逻辑回归的损失函数一般采用对数损失函数
#多元分类:逻辑回归可生成介于0到1.0之间的小数。Softmax将这一想法延伸到多类别领域。
#在多类别问题中,Softmax会为每个类别分配一个用小数表示的概率。这些用小数表示的概率相加之和必须是1.0

#交叉熵损失函数:交叉熵是一个信息论的概念,它原来是用来估算平均编码长度的。
#交叉熵刻画的是两个概率分布之间的距离,p代表正确答案,q代表的预测值,交叉熵越小,两个概率的分布越接近
#定义损失函数
loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))    #交叉熵

#选择优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)     #梯度下降优化器

#定义准确率
# 检查预测类别tf.argmax(pred,1)与实际类别tf.argmax(y,1)的匹配情况
#argmax()将数组中最大值的下标取出来
correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))

#准确率,将布尔值转化为浮点数,并计算平均值    tf.cast()将布尔值投射成浮点数
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#声明会话,初始化变量
sess = tf.Session()
init = tf.global_variables_initializer()   #变量初始化
sess.run(init)

#训练模型
for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs,ys = mnist.train.next_batch(batch_size)  #读取批次数据
        sess.run(optimizer,feed_dict=x:xs,y:ys)   #执行批次训练
        
    #total_batch个批次训练完成后,使用验证数据计算误差与准确率,验证集没有分批
    loss,acc = sess.run([loss_function,accuracy],feed_dict=x:mnist.validation.images,y:mnist.validation.labels)
    
    #打印训练过程中的详细信息
    if (epoch+1) % display_step == 0:
        print("Train Epoch:",%02d%(epoch+1),"Loss=",":.9f".format(loss),"Accuracy=",":.4f".format(acc))
        
print("Train Finished!")
        
#评估模型
#完成训练后,在测试集上评估模型的准确率
accu_test = sess.run(accuracy,feed_dict=x:mnist.test.images,y:mnist.test.labels)
print("Test Accuracy:",accu_test)
#完成训练后,在验证集上评估模型的准确率
accu_validation = sess.run(accuracy,feed_dict=x:mnist.validation.images,y:mnist.validation.labels)
print("Test Accuracy:",accu_validation)
#完成训练后,在训练集上评估模型的准确率
accu_train = sess.run(accuracy,feed_dict=x:mnist.train.images,y:mnist.train.labels)
print("Test Accuracy:",accu_train)

#应用模型
#在建立模型并进行训练后,若认为准确率可以接受,则可以使用此模型进行预测
#由于pred预测结果是one_hot编码格式,所以需要转换成0~9数字
prediction_result = sess.run(tf.argmax(pred,1),feed_dict=x:mnist.test.images)

#查看预测结果中的前10项
prediction_result[0:10]

#定义可视化函数
def plot_images_labels_prediction(images,labels,prediction,index,num=10):  #参数: 图形列表,标签列表,预测值列表,从第index个开始显示,缺省一次显示10幅
    fig = plt.gcf()             #获取当前图表,Get Current Figure
    fig.set_size_inches(10,12)    #1英寸等于2.45cm
    if num > 25 :      #最多显示25个子图
        num = 25
    for i in range(0,num):
        ax = plt.subplot(5,5,i+1)   #获取当前要处理的子图
        ax.imshow(np.reshape(images[index],(28,28)), cmap = binary)              #显示第index个图像
        title = "labels="+str(np.argmax(labels[index]))              #构建该图上要显示的title信息
        if len(prediction)>0:
            title += ",predict="+str(prediction[index])
            
        ax.set_title(title,fontsize=10)    #显示图上的title信息
        ax.set_xticks([])           #不显示坐标轴
        ax.set_yticks([])
        index += 1
    plt.show()
#可视化预测结果
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,10,10)

 

tensorflow之手写数字识别mnist

官方文档: MNISTForMLBeginners- https://www.tensorflow.org/get_started/mnist/beginners DeepMNISTforExperts- https://www.tensorflow.org/get_started/mnist/pros 版本: TensorFlow1.2 查看详情

tensorflow项目实战一:mnist手写数字识别(代码片段)

...测结果,然后计算损失,进行训练。  代码如下:importtensorflowastffromtensorflow.examples.tutorials.mnistimport 查看详情

tensorflow2.0入门教程实战案例

中文文档TensorFlow2/2.0中文文档知乎专栏欢迎关注知乎专栏 https://zhuanlan.zhihu.com/geektutu一、实战教程之强化学习TensorFlow2.0(九)-强化学习70行代码实战PolicyGradientTensorFlow2.0(八)-强化学习DQN玩转gymMountainCarTensorFlow2.0(七)-强化学习Q-Le... 查看详情

tensorflow实践mnist手写数字识别

minst数据集                        tensorflow的文档中就自带了mnist手写数字识别的例子,是一个很经典也比较简单的入门tensorflow的例子,非常值得自己动 查看详情

android+tensorflow+cnn+mnist手写数字识别实现

SyncHere  查看详情

android+tensorflow+cnn+mnist手写数字识别实现

SyncHere  查看详情

tensorflow基础学习五:mnist手写数字识别

MNIST数据集介绍:fromtensorflow.examples.tutorials.mnistimportinput_data#载入MNIST数据集,如果指定地址下没有已经下载好的数据,tensorflow会自动下载数据mnist=input_data.read_data_sets(‘.‘,one_hot=True)#打印Trainingdatasize:55000。print("Train 查看详情

mnist手写数字识别tensorflow(代码片段)

Mnist手写数字识别Tensorflow任务目标了解mnist数据集搭建和测试模型编辑环境操作系统:Win10python版本:3.6集成开发环境:pycharmtensorflow版本:1.*了解mnist数据集mnist数据集:mnist数据集下载地址??MNIST数据集来自美国国家标准与技术研究所,... 查看详情

tensorflow学习笔记mnist手写数字识别(代码片段)

...模型,以达到稳定的高性能,而是学会如何使用Tensorflow解 查看详情

使用tensorflow和mnist识别自己手写的数字

#!/usr/bin/envpython3fromtensorflow.examples.tutorials.mnistimportinput_datamnist=input_data.read_data_sets(‘MNIST_data‘,one_hot=True)importtensorflowastfsess=tf.InteractiveSession()x=tf.placeholder(t 查看详情

tensorflow入门实战|第1周:实现mnist手写数字识别(代码片段)

...语言环境:Python3.6.5编译器:jupyternotebook深度学习环境:TensorFlow2文章目录一、前期工作1.设置GPU(如果使用的是CPU可以忽略这步)2.导入数据3.归一化4.可视化图片5.调整图片格式二、构建CNN网络模型三、编译模型四、训练模型五... 查看详情

tensorflow入门实战|第1周:实现mnist手写数字识别(代码片段)

...语言环境:Python3.6.5编译器:jupyternotebook深度学习环境:TensorFlow2文章目录一、前期工作1.设置GPU(如果使用的是CPU可以忽略这步)2.导入数据3.归一化4.可视化图片5.调整图片格式二、构建CNN网络模型三、编译模型四、训练模型五... 查看详情

07训练tensorflow识别手写数字

   打开PythonShell,输入以下代码:1importtensorflowastf2fromtensorflow.examples.tutorials.mnistimportinput_data34#获取数据(如果存在就读取,不存在就下载完再读取)5mnist=input_data.read_data_sets("MNIST_data/",one_hot=Tr 查看详情

tensorflow1.x代码实战系列:mnist手写数字识别(代码片段)

相关资料https://github.com/aymericdamien/TensorFlow-Examples实战1:MNSIT手写数字识别这里使用的是普通的DNN模型实现MNIST识别,如果想提高精度,可以自行使用CNN模型。'''AlogisticregressionlearningalgorithmexampleusingTensorFlowlibra... 查看详情

tensorflow从入门到精通——手写数字识别(代码片段)

importtensorflowastfprint(tf.__version__)2.6.0一、数据集mnist=tf.keras.datasets.mnist(train_images,train_labels),(test_images,test_labels)=mnist.load_data()train_images.shape,train_labels.shape((60 查看详情

基于tensorflow2.x实现bp神经网络,实践mnist手写数字识别(代码片段)

...,在很多资料中都会被用作深度学习的入门样例。在Tensorflow2.x中该数据集已被封装在了tf.keras.datasets工具包下,如果没有指定数据集的位置,并先前也没有使用过,会自动联网下载该,使该数据集使用起来更... 查看详情

基于tensorflow2.x实现bp神经网络,实践mnist手写数字识别(代码片段)

...,在很多资料中都会被用作深度学习的入门样例。在Tensorflow2.x中该数据集已被封装在了tf.keras.datasets工具包下,如果没有指定数据集的位置,并先前也没有使用过,会自动联网下载该,使该数据集使用起来更... 查看详情

ai常用框架和工具丨11.基于tensorflow(keras)+flask部署mnist手写数字识别至本地web

代码实例,基于TensorFlow+Flask部署MNIST手写数字识别至本地web,希望对您有所帮助。文章目录环境说明文件结构模型训练本地web创建实现效果环境说明操作系统:Windows10CUDA版本为:10.0cudnn版本为:7.6.5Python版本为:Python3.6.13tensorflow-gp... 查看详情