tensorflowmnist手写数字识别之过拟合

双斜杠少年 双斜杠少年     2022-10-01     578

关键词:

1. 过拟合 overfitting 问题

什么是过拟合呢?

用实际生活中的一个例子来比喻一下过拟合现象. 说白了, 就是机器学习模型于自信. 已经到了自负的阶段了. 那自负的坏处, 大家也知道, 就是在自己的小圈子里表现非凡, 不过在现实的大圈子里却往往处处碰壁. 所以在这个简介里, 我们把自负和过拟合画上等号.

学习模型可能太满足了所有的训练数据,所以导致在实际数据中误差陡增,如下图,绿的的线是过拟合的训练结果,黑色的分界线是我门期望的结果。

这里写图片描述

解决过拟合

  • 方法一: 增加数据量, 大部分过拟合产生的原因是因为数据量太少了. 如果我们有成千上万的数据, 红线也会慢慢被拉直, 变得没那么扭曲 .

  • 方法二:运用正规化. L1, l2 regularization等等, 这些方法适用于大多数的机器学习, 包括神经网络. 他们的做法大同小异, 我们简化机器学习的关键公式为 y=Wx . W为机器需要学习到的各种参数. 在过拟合中, W 的值往往变化得特别大或特别小. 为了不让W变化太大, 我们在计算误差上做些手脚. 原始的 cost 误差是这样计算, cost = 预测值-真实值的平方. 如果 W 变得太大, 我们就让 cost 也跟着变大, 变成一种惩罚机制. 所以我们把 W 自己考虑进来. 这里 abs 是绝对值. 这一种形式的 正规化, 叫做 l1 正规化. L2 正规化和 l1 类似, 只是绝对值换成了平方. 其他的l3, l4 也都是换成了立方和4次方等等. 形式类似. 用这些方法,我们就能保证让学出来的线条不会过于扭曲.

  • dropout: 在训练的时候, 我们随机忽略掉一些神经元和神经联结 , 是这个神经网络变得”不完整”. 用一个不完整的神经网络训练一次.

    到第二次再随机忽略另一些, 变成另一个不完整的神经网络. 有了这些随机 drop 掉的规则, 我们可以想象其实每次训练的时候, 我们都让每一次预测结果都不会依赖于其中某部分特定的神经元.

2. 在MNIST中解决过拟合

关于处理数据和构建神经网络 请参照博客 TensorFlow 入门之第一个神经网络和训练 MNIST

代码使用了tensorboard 可视化(关于tensorboard 有时间会专门介绍)来观察 loss 函数的变化,使用 dropout来解决过拟合问题。

下面是主要代码

# -*- coding: utf-8 -*-
# 用 drop out 解决 Overfitting 问题
import 手写数字识别.input_data  
import tensorflow as tf
mnist = 手写数字识别.input_data.read_data_sets("手写数字识别/MNIST_data/", one_hot=True)  

# 添加神经层的函数def add_layer(),它有四个参数:输入值、输入的大小、输出的大小和激励函数,我们设定默认的激励函数是None。也就是线性函数
def add_layer(inputs, in_size, out_size,layer_name, activation_function=None):
    # 定义权重,尽量是一个随机变量
    # 因为在生成初始参数时,随机变量(normal distribution) 会比全部为0要好很多,所以我们这里的weights 是一个 in_size行, out_size列的随机变量矩阵。   
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    # 在机器学习中,biases的推荐值不为0,所以我们这里是在0向量的基础上又加了0.1。
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    # 定义Wx_plus_b, 即神经网络未激活的值(预测的值)。其中,tf.matmul()是矩阵的乘法。
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    #  用dropout来解决过拟合 问题
    Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob)
    # activation_function ——激励函数(激励函数是非线性方程)为None时(线性关系),输出就是当前的预测值——Wx_plus_b,
    # 不为None时,就把Wx_plus_b传到activation_function()函数中得到输出。
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        # 返回输出
        outputs = activation_function(Wx_plus_b)
        tf.summary.histogram(layer_name + '/outputs', outputs)
    return outputs

# 保留数据 keep_prob
keep_prob = tf.placeholder(tf.float32)
xs = tf.placeholder(tf.float32,[None, 784]) #图像输入向量  每个图片有784 (28 *28) 个像素点
ys = tf.placeholder(tf.float32, [None,10]) #每个例子有10 个输出

# prediction = add_layer(xs, 784, 10, activation_function=tf.nn.softmax)
l1 = add_layer(xs, 784, 50, 'l1', activation_function=tf.nn.tanh)
prediction = add_layer(l1, 50, 10,'l2' ,activation_function=tf.nn.softmax)
#loss函数(即最优化目标函数)选用交叉熵函数。交叉熵用来衡量预测值和真实值的相似程度,如果完全相同,它们的交叉熵等于零 ,所以loss 越小 学的好
#分类一般都是 softmax+ cross_entropy
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1]))
tf.summary.scalar('loss', cross_entropy)

#train方法(最优化算法)采用梯度下降法。  优化器 如何让机器学习提升它的准确率。 tf.train.GradientDescentOptimizer()中的值(学习的效率)通常都小于1
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.Session()

merged = tf.summary.merge_all() # tensorflow >= 0.12
train_writer = tf.summary.FileWriter("/Users/yangyibo/test/logs/train",sess.graph)
test_writer = tf.summary.FileWriter("/Users/yangyibo/test/logs/test",sess.graph)

# 初始化变量
init= tf.global_variables_initializer()
sess.run(init)

for i in range(1000):
    #开始train,每次只取100张图片,免得数据太多训练太慢
    batch_xs, batch_ys = mnist.train.next_batch(50)
    # 丢弃百分之40 的数据
    sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})
    if i % 50 == 0:
        train_result = sess.run(merged, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 1})
        test_result = sess.run(merged, feed_dict={xs: mnist.test.images, ys: mnist.test.labels, keep_prob: 1})
        train_writer.add_summary(train_result,i)
        test_writer.add_summary(test_result,i)

本文源码:
https://github.com/527515025/My-TensorFlow-tutorials/blob/master/tensorflow_10_dropout.py
欢迎 start

本文参考:莫烦老师的pyhon课程

吴恩达ml之过拟合问题

       查看详情

模型评估之过拟合和欠拟合

...程中,往往会遇到“过拟合”和欠拟合现象,如何有效地识别“过拟合”和“欠拟合”现象,并有针对性地调整模型,是不断改进机器学习模型的关键。给定一个假设空间F,一个假设f属于F,如果存在其他的假设f′也属于F,使得... 查看详情

手写数字识别基于matlabgui知识库手写数字识别(写字板+图片)含matlab源码1227期

一、手写数字识别技术简介1案例背景手写体数字识别是图像识别学科下的一个分支,是图像处理和模式识别研究领域的重要应用之一,并且具有很强的通用性。由于手写体数字的随意性很大,如笔画粗细、字体大小、倾斜角度... 查看详情

matlab使用cnn拟合回归模型预测手写数字的旋转角度

✅作者简介:热爱科研的算法开发者,Python、Matlab项目可交流、沟通、学习。 查看详情

手写数字识别基于matlabguibp神经网络单个或连续手写数字识别系统含matlab源码2296期

⛄一、手写数字识别技术简介1案例背景手写体数字识别是图像识别学科下的一个分支,是图像处理和模式识别研究领域的重要应用之一,并且具有很强的通用性。由于手写体数字的随意性很大,如笔画粗细、字体大... 查看详情

手写数字识别基于matlabguibp神经网络单个或连续手写数字识别系统含matlab源码2296期

⛄一、手写数字识别技术简介1案例背景手写体数字识别是图像识别学科下的一个分支,是图像处理和模式识别研究领域的重要应用之一,并且具有很强的通用性。由于手写体数字的随意性很大,如笔画粗细、字体大... 查看详情

手写数字识别基于matlabgui欧拉数和二维矩阵相关系数手写数字识别含matlab源码1896期(代码片段)

一、手写数字识别简介1引言数字识别技术是图像处理领域中的一个研究热点,在食品、化妆品、药品等外包装生产日期提取上具有重要的实用价值。近年来,随着人们对数字图像识别算法的不断研究,数字图像识别方法也越来越多,... 查看详情

keras入门实战:手写数字识别

...。本文用深度学习Python库Keras实现深度学习入门教程mnist手写数字识别。mnist手写数字识别是机器学习和深度学习领域的“helloworld”,MNIST数据集是手写数字的数据集合,训练集规模为60000,测试集为100 查看详情

手写数字识别基于matlabguibp神经网络手写数字识别系统含matlab源码1639期(代码片段)

一、手写数字识别技术简介(附lunwen)1案例背景手写体数字识别是图像识别学科下的一个分支,是图像处理和模式识别研究领域的重要应用之一,并且具有很强的通用性。由于手写体数字的随意性很大,如笔... 查看详情

pytorch实现手写数字识别(代码片段)

PyTorch实现手写数字识别使用Pytorch实现手写数字识别1.思路和流程分析2.准备训练集和测试集2.1torchvision.transforms的图形数据处理方法2.1.1torchvision.transforms.ToTensor2.1.2torchvision.transforms.Normalize(mean,std)2.1.3torchvision.transforms.Com 查看详情

opencv——识别手写体数字

这个是树莓派上运行的,opencv3opencv提供了一张手写数字图片给我们,如下图所示,可以作为识别手写数字的样本库。0到9共十个数字,每个数字有五行,一行100个数字。首先要把这5000个数字截取出来。图片大小为1000*2000,则每个... 查看详情

手写数字识别基于matlabfisher分类手写数字识别含matlab源码505期(代码片段)

一、简介Fisher分类器用于解决二类线性可分问题。Fisher准则基本原理:找到一个最合适的投影轴,使两类样本在该轴上投影之间的距离尽可能远,而每一类样本的投影尽可能紧凑,从而使分类效果为最佳。例如上图中:通过将方... 查看详情

knn分类算法实现手写数字识别

需求:利用一个手写数字“先验数据”集,使用knn算法来实现对手写数字的自动识别;先验数据(训练数据)集:?数据维度比较大,样本数比较多。? 数据集包括数字0-9的手写体。?每个数字大约有200个样本。?每个样本保持... 查看详情

手写数字识别基于matlabguibp神经网络手写数字识别含matlab源码1118期(代码片段)

一、简介1概述BP(BackPropagation)神经网络是1986年由Rumelhart和McCelland为首的科研小组提出,参见他们发表在Nature上的论文Learningrepresentationsbyback-propagatingerrors。BP神经网络是一种按误差逆传播算法训练的多层前馈网络... 查看详情

python,opencv使用knn来构建手写数字及字母识别ocr(代码片段)

Python,OpenCV使用KNN来构建手写数字及字母识别OCR1.原理1.1手写数字识别1.2字母识别2.源码2.1手写数字OCR2.2字母OCR参考这篇博客将介绍如何借助OpenCV提供的手写数字及字母数据集,来构建训练KNN模型,以进行手写数字及... 查看详情

手写数字识别基于matlabcnn网络手写数字识别分类含matlab源码1286期

一、CNN简介1机器如何识图先给大家出个脑筋急转弯:在白纸上画出一个大熊猫,一共需要几种颜色的画笔?——大家应该都知道,只需要一种黑色的画笔,只需要将大熊猫黑色的地方涂上黑色,一个大熊猫的图像就可以展现出... 查看详情

图像识别基于模板匹配之手写数字识别系统gui界面(代码片段)

...问题。光学字符识别的出现为这一问题提供了解决方法。手写体数字识别是光学字符识别的重要分支,因其在金融、邮政、医疗、交通、教育等领域中广泛的应用而日益被重视。目前,已有多种手写体数字识别算法,但都很难满足手... 查看详情

基于tensorflow的手写数字识别代码(代码片段)

基于tensorflow的手写数字识别代码fromkeras.utilsimportto_categoricalfromkerasimportmodels,layers,regularizersfromkeras.optimizersimportRMSpropfromkeras.datasetsimportmnist(train_images,train_labels),(test_images, 查看详情