pytorch实现rnn网络对mnist字体分类(代码片段)

城南皮卡丘 城南皮卡丘     2023-04-01     309

关键词:

 我们知道,循环神经网络RNN非常擅长处理序列数据,但它也可以用来处理图像数据,这是因为一张图像可以看作一组由很长的像素点组成的序列。下面将会使用RNN对MNIST数据集建立分类器。

目录

1.准备数据集、定义数据加载器

2.搭建RNN网络 

3.RNN网络的训练与预测


1.准备数据集、定义数据加载器

在进行数据准备工作时,可以直接从torchvision库的datasets模块导入MNIST手写字体的训练数据集和测试数据集,然后使用Data.DataLoader()函数将两个数据集定义为数据加载器,其中每个batch包含64张图像,最后得到训练集数据加载器train_loader与测试集数据加载器test_loader。在导入的数据集中,训练集包含60000张28×28的灰度图像,测试集包含10000张28×28的灰度图像。

#准备训练数据集
train_data=torchvision.datasets.MNIST(
    root="../Dataset",
    train=True,
    transform=transforms.ToTensor(),
    download=False
)
#定义一个数据加载器
train_loader=Data.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
    num_workers=0
)
#准备测试数据集
test_data=torchvision.datasets.MNIST(
    root="../Dataset",
    train=False,
    transform=transforms.ToTensor(),
    download=False
)
#定义测试数据集的数据加载器
test_loader=Data.DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
    num_workers=0
)

2.搭建RNN网络 

下面程序定义了MyRnnMnistNet类,在调用时需要输入4个参数,参数input_dim表示输入数据的维度,针对图像分类器,其值是图片中每行的数据像素点量,针对手写字体数据其值等于28;参数hidden_dim表示构建RNN网络层中包含神经元的个数;参数layer_dim表示在RNN网络层中有多少层RNN神经元;参数output_dim则表示在使用全连接层进行分类时输出的维度,可以使用数据的类别数表示,MNIST数据集表示为10

#定义网络
class MyRnnMnistNet(nn.Module):
    def __init__(self,input_dim,hidden_dim,layer_dim,output_dim):
        # input_dim:输入数据的维度(图片每行的数据像素点)
        #hidden_dim:RNN神经元个数
        #layer_dim:RNN层数
        #output_dim:隐藏层输出的维度(分类的数量)
        super(MyRnnMnistNet,self).__init__()
        self.hidden_dim=hidden_dim
        self.layer_dim=layer_dim
        self.rnn=nn.RNN(input_dim,hidden_dim,layer_dim,batch_first=True,nonlinearity='relu')
        #连接全连接层
        self.fc1=nn.Linear(hidden_dim,output_dim)
    def forward(self,x):
        # x:[batch,time_step,input_dim]
        #time_step=图像所有像素数量 / input_size
        #out:[batch,time_step,output_size]
        out,h_n=self.rnn(x,None)#None表示h0会使用全0进行初始化
        #选取最后一个时间点的out输出
        out=self.fc1(out[:,-1,:])
        return out
#模型的调用
input_dim=28#图片每行的像素数量
hidden_dim=128#RNN神经元个数
layer_dim=1#RNN层数
output_dim=10#隐藏层输出的维度(10类图像)
MyRnnMnistNet=MyRnnMnistNet(input_dim,hidden_dim,layer_dim,output_dim)
#print(MyRnnMnistNet)

在MyRnnMnistNet类调用nn.RNN()函数时,参数batch_first=True表示使用的数据集中batch在数据的第一个维度,参数nonlinearity='relu'表示RNN层使用的激活函数为ReLU函数。从RNNimc类的forward()函数中,发现网络的self.rnn()层的输入有两个参数,第一个参数为需要分析的数据x,第二个参数则为初始的隐藏层输出,这里使用None代替,表示使用全0进行初始化。而输出则包含两个参数,其中out表示RNN最后一层的输出特征,h_n表示隐藏层的输出。在将RNN层和全连接分类层连接时,将全连接层网络作用于最后一个时间点的out输出。

3.RNN网络的训练与预测

对定义好的网络模型使用训练集进行训练,需要定义优化器和损失函数,优化器使用torch.optim.RMSprop()定义,损失函数则使用交叉嫡损失nn.CrossEntropyLoss()函数定义,并且使用训练集对网络训练30个epoch 

#对模型进行训练
optimizer=torch.optim.RMSprop(MyRnnMnistNet.parameters(),lr=0.0003)
criterion=nn.CrossEntropyLoss()#损失函数
train_loss_all=[]
train_acc_all=[]
test_loss_all=[]
test_acc_all=[]
num_epochs=30
for epoch in range(num_epochs):
    print("Epoch  / ".format(epoch,num_epochs-1))
    MyRnnMnistNet.train()
    corrects=0
    train_num=0
    for step,(b_x,b_y) in enumerate(train_loader):
        xdata=b_x.view(-1,28,28)
        output=MyRnnMnistNet(xdata)
        pre_lab=torch.argmax(output,1)
        loss=criterion(output,b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss+=loss.item() * b_x.size(0)
        corrects +=torch.sum(pre_lab == b_y.data)
        train_num += b_x.size(0)
    #经过一个计算epoch的训练后在训练集上的损失和精度
    train_loss_all.append((loss / train_num).detach().numpy())
    train_acc_all.append(corrects.double().item()/train_num)
    print('Train Loss::.4f Train Acc::.4f'.format(epoch,train_loss_all[-1],train_acc_all[-1]))
    #设置模型为验证模式
    MyRnnMnistNet.eval()
    corrects=0
    test_num=0
    for step,(b_x,b_y) in enumerate(test_loader):
        xdata=b_x.view(-1,28,28)
        output=MyRnnMnistNet(xdata)
        pre_lab=torch.argmax(output,1)
        loss=criterion(output,b_y)
        loss +=loss.item() * b_x.size(0)
        corrects +=torch.sum(pre_lab == b_y.data)
        test_num +=b_x.size(0)
    #计算经过一个epoch的训练后在测试集上的损失和精度
    test_loss_all.append((loss / test_num).detach().numpy())
    test_acc_all.append(corrects.double().item() /test_num)
    print("Test Loss::.4f Test Acc::.4f".format(epoch,test_loss_all[-1],test_acc_all[-1]))

在上面的程序中,每使用训练集对网络进行一轮训练,都会使用测试集来测试当前网络的分类效果,针对训练集和测试集的每个epoch损失和预测精度,都保存在train_loss_all、train_acc_all、test_loss_all、test_acc_all四个列表中。在每个epoch中使用MyRNNimc.train()将网络切换为训练模式,使用MyRnnMnistNet.eval()将网络切换为验证模式。针对网络的输入x,需要从图像数据集[batch, channel, height,width]转化为[batch, time_step, input_dim],即从[64,1,28,28]转化为[64,28,28]。在网络训练完毕后,将网络在训练集和测试集上的损失及预测精度使用折线图可视化,得到的图像如下图所示

#可视化模型训练过程
plt.figure(figsize=(14,5))
plt.subplot(1,2,1)
plt.plot(train_loss_all,"ro-",label="Train Loss")
plt.plot(test_loss_all,"bs-",label="Val Loss")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.subplot(1,2,2)
plt.plot(train_acc_all,"ro-",label="Train acc")
plt.plot(test_acc_all,"bs-",label="Val acc")
plt.xlabel("epoch")
plt.ylabel("acc")
plt.legend()
plt.show()

 上图中左图所示为每个epoch在训练集和测试集上的损失函数的变化情况,右图所示则为每个epoch在训练集和测试集上的预测精度的变化情况,精度最终稳定在0.97附近。可见使用RNN网络也能很好地对手写字体图像数据进行预测。

4.该案例的完整代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import copy
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.utils.data as Data
from  torchvision import transforms

#准备训练数据集
train_data=torchvision.datasets.MNIST(
    root="../Dataset",
    train=True,
    transform=transforms.ToTensor(),
    download=False
)
#定义一个数据加载器
train_loader=Data.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
    num_workers=0
)
#准备测试数据集
test_data=torchvision.datasets.MNIST(
    root="../Dataset",
    train=False,
    transform=transforms.ToTensor(),
    download=False
)
#定义测试数据集的数据加载器
test_loader=Data.DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
    num_workers=0
)


#定义网络
class MyRnnMnistNet(nn.Module):
    def __init__(self,input_dim,hidden_dim,layer_dim,output_dim):
        # input_dim:输入数据的维度(图片每行的数据像素点)
        #hidden_dim:RNN神经元个数
        #layer_dim:RNN层数
        #output_dim:隐藏层输出的维度(分类的数量)
        super(MyRnnMnistNet,self).__init__()
        self.hidden_dim=hidden_dim
        self.layer_dim=layer_dim
        self.rnn=nn.RNN(input_dim,hidden_dim,layer_dim,batch_first=True,nonlinearity='relu')
        #连接全连接层
        self.fc1=nn.Linear(hidden_dim,output_dim)
    def forward(self,x):
        # x:[batch,time_step,input_dim]
        #time_step=图像所有像素数量 / input_size
        #out:[batch,time_step,output_size]
        out,h_n=self.rnn(x,None)#None表示h0会使用全0进行初始化
        #选取最后一个时间点的out输出
        out=self.fc1(out[:,-1,:])
        return out
#模型的调用
input_dim=28#图片每行的像素数量
hidden_dim=128#RNN神经元个数
layer_dim=1#RNN层数
output_dim=10#隐藏层输出的维度(10类图像)
MyRnnMnistNet=MyRnnMnistNet(input_dim,hidden_dim,layer_dim,output_dim)
#print(MyRnnMnistNet)
#可视化神经网络


#对模型进行训练
optimizer=torch.optim.RMSprop(MyRnnMnistNet.parameters(),lr=0.0003)
criterion=nn.CrossEntropyLoss()#损失函数
train_loss_all=[]
train_acc_all=[]
test_loss_all=[]
test_acc_all=[]
num_epochs=30
for epoch in range(num_epochs):
    print("Epoch  / ".format(epoch,num_epochs-1))
    MyRnnMnistNet.train()
    corrects=0
    train_num=0
    for step,(b_x,b_y) in enumerate(train_loader):
        xdata=b_x.view(-1,28,28)
        output=MyRnnMnistNet(xdata)
        pre_lab=torch.argmax(output,1)
        loss=criterion(output,b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss+=loss.item() * b_x.size(0)
        corrects +=torch.sum(pre_lab == b_y.data)
        train_num += b_x.size(0)
    #经过一个计算epoch的训练后在训练集上的损失和精度
    train_loss_all.append((loss / train_num).detach().numpy())
    train_acc_all.append(corrects.double().item()/train_num)
    print('Train Loss::.4f Train Acc::.4f'.format(epoch,train_loss_all[-1],train_acc_all[-1]))
    #设置模型为验证模式
    MyRnnMnistNet.eval()
    corrects=0
    test_num=0
    for step,(b_x,b_y) in enumerate(test_loader):
        xdata=b_x.view(-1,28,28)
        output=MyRnnMnistNet(xdata)
        pre_lab=torch.argmax(output,1)
        loss=criterion(output,b_y)
        loss +=loss.item() * b_x.size(0)
        corrects +=torch.sum(pre_lab == b_y.data)
        test_num +=b_x.size(0)
    #计算经过一个epoch的训练后在测试集上的损失和精度
    test_loss_all.append((loss / test_num).detach().numpy())
    test_acc_all.append(corrects.double().item() /test_num)
    print("Test Loss::.4f Test Acc::.4f".format(epoch,test_loss_all[-1],test_acc_all[-1]))
#可视化模型训练过程
plt.figure(figsize=(14,5))
plt.subplot(1,2,1)
plt.plot(train_loss_all,"ro-",label="Train Loss")
plt.plot(test_loss_all,"bs-",label="Val Loss")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.subplot(1,2,2)
plt.plot(train_acc_all,"ro-",label="Train acc")
plt.plot(test_acc_all,"bs-",label="Val acc")
plt.xlabel("epoch")
plt.ylabel("acc")
plt.legend()
plt.show()

pytorch学习-7:rnn循环神经网络(分类)(代码片段)

pytorch学习-7:RNN循环神经网络(分类)1.加载MNIST手写数据1.1数据预处理2.RNN模型建立3.训练4.预测参考循环神经网络让神经网络有了记忆,对于序列话的数据,循环神经网络能达到更好的效果.1.加载MNIST手写数据importtorchfr... 查看详情

基于pytorch平台实现对mnist数据集的分类分析(前馈神经网络softmax)基础版(代码片段)

基于pytorch平台实现对MNIST数据集的分类分析(前馈神经网络、softmax)基础版文章目录基于pytorch平台实现对MNIST数据集的分类分析(前馈神经网络、softmax)基础版前言一、基于“前馈神经网络”模型,分类分析... 查看详情

Tensorflow.keras:RNN 对 Mnist 进行分类

】Tensorflow.keras:RNN对Mnist进行分类【英文标题】:Tensorflow.keras:RNNtoclassifyMnist【发布时间】:2020-11-1306:04:08【问题描述】:我试图通过构建一个简单的数字分类器来理解tensorflow.keras.layers.SimpleRNN。Mnist数据集的数字大小为28X28。所... 查看详情

pytorch之神经网络mnist分类任务(代码片段)

...使用TensorDataset和DataLoader简化本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052一、Mnist分类任务简介在上一篇博客当中,我们通过搭建PyTorch神经网络实现了气温预测,这本质上是一个回归任务。在... 查看详情

使用 pytorch 和 sklearn 对 MNIST 数据集进行交叉验证

】使用pytorch和sklearn对MNIST数据集进行交叉验证【英文标题】:CrossvalidationforMNISTdatasetwithpytorchandsklearn【发布时间】:2020-03-1815:27:22【问题描述】:我是pytorch的新手,正在尝试实现前馈神经网络来对mnist数据集进行分类。我在尝... 查看详情

用pytorch实现多层感知机(mlp)(全连接神经网络fc)分类mnist手写数字体的识别(代码片段)

1.导入必备的包1importtorch2importnumpyasnp3fromtorchvision.datasetsimportmnist4fromtorchimportnn5fromtorch.autogradimportVariable6importmatplotlib.pyplotasplt7importtorch.nn.functionalasF8fromtorch.utils.da 查看详情

循环神经网络rnn从零开始实现动手学深度学习v2pytorch

1.循环神经网络RNN2.从零开始实现3.简洁实现4.Q&AGPT3,BERT都是基于Transformer模型的改进,目前也是最火的。voice和image融合算法,用多模态模型。比如自动驾驶领域的运用。RNN批量大小乘以时间长度t,相当于做t次分类... 查看详情

利用knnsvmcnn逻辑回归mlprnn等方法实现mnist数据集分类(pytorch实现)(代码片段)

电脑配置:python3.6*​*Pytorch1.2.0*​*torchvision0.4.0想学习机器学习和深度学习的同学,首先找个比较经典的案例和经典的方法自己动手试一试,分析这些方法的思想和每一行代码是一个快速入门的小技巧,今天我们... 查看详情

Pytorch MNIST 自动编码器学习 10 位分类

】PytorchMNIST自动编码器学习10位分类【英文标题】:PytorchMNISTautoencodertolearn10-digitclassification【发布时间】:2021-06-1411:42:28【问题描述】:我正在尝试为MNIST构建一个简单的自动编码器,其中中间层只有10个神经元。我希望它能够... 查看详情

lenet模型对cifar-10数据集分类pytorch(代码片段)

LeNet模型对CIFAR-10数据集分类【pytorch】目录LeNet网络模型CIFAR-10数据集Pytorch实现代码目录本文为针对CIFAR-10数据集的基于简单神经网络LeNet分类实现(pytorch实现)LeNet网络模型Tip:(以上为原始LeNet)为了更好的... 查看详情

mnist数据集手写体识别(rnn实现)(代码片段)

...络模型(mnist手写体识别数据集)MNIST数据集手写体识别(CNN实现)importtensorflowastfimporttensorflow.examples.tutorials.mnist.input_dataasinput_data#导入 查看详情

pytorch实现简单的分类器(代码片段)

作为目前越来越受欢迎的深度学习框架,pytorch基本上成了新人进入深度学习领域最常用的框架。相比于TensorFlow,pytorch更易学,更快上手,也可以更容易的实现自己想要的demo。今天的文章就从pytorch的基础开始,帮助大家实现成... 查看详情

pytorchnote25深层神经网络实现mnist手写数字分类(代码片段)

PytorchNote25深层神经网络实现MNIST手写数字分类文章目录PytorchNote25深层神经网络实现MNIST手写数字分类MNIST数据集多分类问题softmax交叉熵多层全连接神经网络实现MINST手写数字分类数据预处理简单的四层全连接神经网络定义loss函数... 查看详情

神经网络的学习-搭建神经网络实现mnist数据集分类(代码片段)

...失函数2.损失函数的意义3.数值微分4.梯度法5.学习算法的实现四、神经网络的学习这一章通过两层神经网络实现对mnist手写数据集的识别,本文是源于《深度学习入门》的学习笔记若理解困难,参考上一章笔记:深度... 查看详情

pytorch笔记-recurrentneuralnetwork(rnn)循环神经网络(代码片段)

循环神经网络,RNN(RecurrentNeuralNetwork):记忆单元分类:RNN(RecurrentNeuralNetwork)、GRU(GateRecurrentUnit)、LSTM(LongShort-TermMemory)模型类别:单向循环、双向循环、多层单向或双向叠加优缺点:优点:可以处理变长序列 查看详情

基于pytorch后量化(mnist分类)---浮点训练vs多bit后量化vs多bit量化感知训练效果对比

基于pytorch后量化(mnist分类)—浮点训练vs多bit后量化vs多bit量化感知训练效果对比代码下载地址:下载地址试了bit数为1~8的准确率,得到下面这张折线图:发现,当bit>=3的时候,精度几乎不会掉,bit=2的时候精度下降到69%,b... 查看详情

pytorchnote38rnn做图像分类(代码片段)

...能想CNN一样用来做图像分类呢?下面我们用mnist手写字体的例子来展示一下如何用RNN做图像分类,但是这种方法并不是主流,这里 查看详情

图像分类基于pytorch搭建lstm实现mnist手写数字体识别(双向lstm,附完整代码和数据集)(代码片段)

...sdn.net/AugustMe/article/details/128969138文章中,我们使用了基于PyTorch搭建LSTM实现MNIST手写数字体识别,LSTM是单向的,现在我们使用双向LSTM试一试效果,和之前的单向LSTM模型稍微有差别,请注意查看代码的变化。1.导入依赖库这些依赖... 查看详情