tensorflow编程基础之mnist手写识别实验+关于cross_entropy的理解(代码片段)

kerwins-ac kerwins-ac     2023-01-18     377

关键词:

好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前,不知道我大学的那些小伙伴们现在都怎么样了,考研的刚刚希望他考上,实习的菜头希望他早日脱离苦海,小瑞哥希望他早日出成果,范爷熊健研究生一定要过的开心啊!天哥也哥早日结婚领证!那些回不去的曾经的快乐的时光,你们都还好吗!

 

最近开始接触Tensorflow,可能是论文里用的是这个框架吧,其实我还是觉得pytorch更方便好用一些,仔细读了最简单的Mnist手写识别程序,觉得大同小异,关键要理解Tensorflow的思想,文末就写一下自己看交叉熵的感悟,絮叨了这么多开始写点代码吧!  1 # -*- coding: utf-8 -*-

  2 """
  3 Created on Sun Nov 11 16:14:38 2018
  4 
  5 @author: Yang
  6 """
  7 
  8 import tensorflow as tf 
  9 from tensorflow.examples.tutorials.mnist import input_data 
 10 
 11 mnist = input_data.read_data_sets("/MNIST_data",one_hot=True) #从input_data中读取数据集,使用one_hot编码
 12 
 13 import pylab #画图模块
 14 
 15 tf.reset_default_graph()#重置一下图 图代表了一个运算过程,包含了许多Variable和op,如果不重置一下图的话,可能会因为某些工具重复调用变量而报错
 16 
 17 x = tf.placeholder(tf.float32,[None,784])#占位符,方便用feed_dict进行注入操作
 18 y = tf.placeholder(tf.float32,[None,10])#占位符,方便用feed_dict进行注入操作
20 21 W = tf.Variable(tf.random_normal([784,10]))#要学习的参数统一用Variable来定义,这样方便进行调整更新 22 b = tf.Variable(tf.zeros([10])) 23 24 25 #construct the model 26 pred = tf.nn.softmax(tf.matmul(x,W) + b) 27 28 cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 29 30 learning_rate = 0.01 31 32 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) 33 34 #set parameters about thee model 35 training_epoch = 25 36 batch_size = 100 37 display_step = 1 38 saver = tf.train.Saver() 39 model_path = "log/kerwinsmodel.ckpt" 40 41 #start the session 42 43 with tf.Session() as sess : 44 sess.run(tf.global_variables_initializer()) 45 46 for epoch in range(training_epoch): 47 avg_cost = 0 48 total_batch = int(mnist.train.num_examples/batch_size) 49 print(total_batch) 50 for i in range(total_batch): 51 batch_xs,batch_ys = mnist.train.next_batch(batch_size) 52 53 _,c = sess.run([optimizer,cost],feed_dict=x:batch_xs,y:batch_ys) 54 55 avg_cost += c/ total_batch 56 if (epoch +1 ) % display_step ==0: 57 print("Epoch:",%04d %(epoch+1),"cost=",":.9f".format(avg_cost)) 58 59 print("Finish!") 60 61 correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 62 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 63 print("Accuracy:",accuracy.eval(x:mnist.test.images,y:mnist.test.labels)) 64 65 save_path = saver.save(sess,model_path) 66 print("Model saved in file: %s" % save_path) 67 # 68 69 70 #读取模型程序 71 72 print("Starting 2nd session...") 73 with tf.Session() as sess: 74 sess.run(tf.global_variables_initializer()) 75 saver.restore(sess,model_path) 76 77 #测试model 78 correct_prediction = tf.equal(tf.arg_max(pred,1),tf.argmax(y,1)) 79 #计算准确率 80 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 81 print("Accuracy:",accuracy.eval(x:mnist.test.images,y:mnist.test.labels)) 82 83 output = tf.argmax(pred,1) 84 batch_xs,batch_ys = mnist.train.next_batch(2) 85 outputval,predv = sess.run([output,pred],feed_dict=x:batch_xs) 86 print(outputval,predv,batch_ys) 87 88 im = batch_xs[0] 89 im = im.reshape(-1,28) 90 pylab.imshow(im) 91 pylab.show() 92 93 im = batch_xs[1] 94 im = im.reshape(-1,28) 95 pylab.imshow(im) 96 pylab.show() 97 98 99 100

 

#占位符,方便用feed_dict进行注入操作

tensorflow笔记之mnist手写识别系列一

tensorflow笔记(四)之MNIST手写识别系列一版权声明:本文为博主原创文章,转载请指明转载地址http://www.cnblogs.com/fydeblog/p/7436310.html前言这篇博客将利用神经网络去训练MNIST数据集,通过学习到的模型去分类手写数字。我会将本篇... 查看详情

tensorflow基础学习五:mnist手写数字识别

MNIST数据集介绍:fromtensorflow.examples.tutorials.mnistimportinput_data#载入MNIST数据集,如果指定地址下没有已经下载好的数据,tensorflow会自动下载数据mnist=input_data.read_data_sets(‘.‘,one_hot=True)#打印Trainingdatasize:55000。print("Train 查看详情

tensorflow之mnist手写数字识别:分类问题(代码片段)

整体代码:#数据读取importtensorflowastfimportmatplotlib.pyplotaspltimportnumpyasnpfromtensorflow.examples.tutorials.mnistimportinput_datamnist=input_data.read_data_sets("MNIST_data/",one_hot=True)#定义待输入数据的占位符# 查看详情

77tensorflow手写识别基础版本

‘‘‘Createdon2017年4月20日@author:weizhen‘‘‘#手写识别fromtensorflow.examples.tutorials.mnistimportinput_datamnist=input_data.read_data_sets("/path/to/MNIST_data/",one_hot=True)batch_size=100xs,ys=mnist.train. 查看详情

tensorflow2.0入门教程实战案例

中文文档TensorFlow2/2.0中文文档知乎专栏欢迎关注知乎专栏 https://zhuanlan.zhihu.com/geektutu一、实战教程之强化学习TensorFlow2.0(九)-强化学习70行代码实战PolicyGradientTensorFlow2.0(八)-强化学习DQN玩转gymMountainCarTensorFlow2.0(七)-强化学习Q-Le... 查看详情

tensorflow实践mnist手写数字识别

minst数据集                        tensorflow的文档中就自带了mnist手写数字识别的例子,是一个很经典也比较简单的入门tensorflow的例子,非常值得自己动 查看详情

android+tensorflow+cnn+mnist手写数字识别实现

SyncHere  查看详情

android+tensorflow+cnn+mnist手写数字识别实现

SyncHere  查看详情

基于tensorflow的mnist手写识别

这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环。我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟。在TensorFlow的中文介绍文档中的... 查看详情

mnist手写数字识别tensorflow(代码片段)

Mnist手写数字识别Tensorflow任务目标了解mnist数据集搭建和测试模型编辑环境操作系统:Win10python版本:3.6集成开发环境:pycharmtensorflow版本:1.*了解mnist数据集mnist数据集:mnist数据集下载地址??MNIST数据集来自美国国家标准与技术研究所,... 查看详情

使用tensorflow和mnist识别自己手写的数字

#!/usr/bin/envpython3fromtensorflow.examples.tutorials.mnistimportinput_datamnist=input_data.read_data_sets(‘MNIST_data‘,one_hot=True)importtensorflowastfsess=tf.InteractiveSession()x=tf.placeholder(t 查看详情

tensorflow学习笔记mnist手写数字识别(代码片段)

...模型,以达到稳定的高性能,而是学会如何使用Tensorflow解 查看详情

tensorflow基础入门(代码片段)

...      #MNIST手写数字识别数据集importtensorflowastfimporttensorflow.examples.tutorials.mnist.input_dataasinput_dataimportnumpyasnpimportmatplotlib.pyplotaspltmnist=input_data.read_data_sets("MNST_data/",one_hot=True)#了解MNIST手写数字识别数据集print("... 查看详情

tensorflow项目实战一:mnist手写数字识别(代码片段)

...测结果,然后计算损失,进行训练。  代码如下:importtensorflowastffromtensorflow.examples.tutorials.mnistimport 查看详情

tensorflow实战手写数字识别(tensorboard可视化)

一、前言为了更好的理解NeuralNetwork,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试。同时使用Tensorboard对训练过程进行可视化,算是打响学习Tensorflow的第一枪啦。看本文之前,希望你已经具备机器... 查看详情

tensorflow入门实战|第1周:实现mnist手写数字识别(代码片段)

...语言环境:Python3.6.5编译器:jupyternotebook深度学习环境:TensorFlow2文章目录一、前期工作1.设置GPU(如果使用的是CPU可以忽略这步)2.导入数据3.归一化4.可视化图片5.调整图片格式二、构建CNN网络模型三、编译模型四、训练模型五... 查看详情

tensorflow中使用cnn实现mnist手写体识别

  本文参考YannLeCun的LeNet5经典架构,稍加ps得到下面适用于本手写识别的cnn结构,构造一个两层卷积神经网络,神经网络的结构如下图所示:  输入-卷积-pooling-卷积-pooling-全连接层-Dropout-Softmax输出    第一层卷积利用5*... 查看详情

tensorflow入门实战|第1周:实现mnist手写数字识别(代码片段)

...语言环境:Python3.6.5编译器:jupyternotebook深度学习环境:TensorFlow2文章目录一、前期工作1.设置GPU(如果使用的是CPU可以忽略这步)2.导入数据3.归一化4.可视化图片5.调整图片格式二、构建CNN网络模型三、编译模型四、训练模型五... 查看详情