pytorch使用mnist数据集实现手写数字识别(代码片段)

1-0001 1-0001     2022-12-01     478

关键词:

使用mnist数据集实现手写数字识别是入门必做吧。这里使用pyTorch框架进行简单神经网络的搭建。

首先导入需要的包。

1 import torch
2 import torch.nn as nn
3 import torch.utils.data as Data
4 import torchvision

 

接下来需要下载mnist数据集。我们创建train_data。使用torchvision.datasets.MNIST进行数据集的下载。

1 train_data = torchvision.datasets.MNIST(
2     root=./mnist/,   #下载到该目录下
3     train=True,                                     #为训练数据
4     transform=torchvision.transforms.ToTensor(),    #将其装换为tensor的形式
5     download=True, #第一次设置为true表示下载,下载完成后,将其置成false
6 )

 之后将其导入data_loader中,这个数据加载类会自动帮我们进行数据集的切片。

 1 train_data = torchvision.datasets.MNIST(
 2     root=./mnist,
 3     train=True,
 4     transform=torchvision.transforms.ToTensor(),
 5     download=False
 6 )
 7 train_loader = Data.DataLoader(dataset=train_data, batch_size=32, shuffle=True, num_workers=0)
 8 test_data = torchvision.datasets.MNIST(
 9     root=./mnist,
10     train=False,
11     transform=torchvision.transforms.ToTensor(),
12 )
13 test_loader = Data.DataLoader(dataset=test_data, batch_size=32, shuffle=False, num_workers=0)
14 test_num = len(test_data)

之后开始定义我们的模型,由于minist数据集是灰度图像,并且图片的size都是(28, 28, 1),所以输入图片的时候不需要进行额外的修改。

 1 class Net(nn.Module):
 2     def __init__(self):
 3         super(Net, self).__init__()
 4         self.conv1 = nn.Sequential(#(1, 28, 28)
 5             nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),#(16, 28, 28)
 6             nn.ReLU(),#(16, 28, 28)
 7             nn.MaxPool2d(kernel_size=2)#(16, 14, 14)
 8         )
 9         self.conv2 = nn.Sequential(
10             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),#(32, 14, 14)
11             nn.ReLU(),#(32, 14, 14)
12             nn.MaxPool2d(kernel_size=2)#(32, 7, 7)
13         )
14         self.fc = nn.Linear(32 * 7 * 7, 10)
15     def forward(self, x):
16         x = self.conv1(x)
17         x = self.conv2(x)
18         x = x.view(x.size(0), -1)
19         x = self.fc(x)
20         return x

特别注意在最后传入全连接层时,最好自己将x的size改变以确保不会因为自适应而造成错误。因为在传入全连接层时会默认压缩成二维,例如[1, 2, 3, 4]会被压缩成[1*2, 3*4]。

之后开始训练。

 1 net = Net()
 2 loss_fn = nn.CrossEntropyLoss()
 3 optim = torch.optim.Adam(net.parameters(), lr = 0.001)
 4 
 5 save_path = ./mnist.pth
 6 best_acc = 0.0
 7 for epoch in range(3):
 8 
 9     net.train()
10     running_loss = 0.0
11     for step, data in enumerate(train_loader, start=0):
12         images, labels = data
13         optim.zero_grad()
14         logits = net(images)
15         loss = loss_fn(logits, labels)
16         loss.backward()
17         optim.step()
18 
19 
20         running_loss += loss.item()
21         rate = (step+1)/len(train_loader)
22         a = "*" * int(rate * 50)
23         b = "." * int((1 - rate) * 50)
24         print("
train loss: :^3.0f%[->]:.4f".format(int(rate*100), a, b, loss), end="")
25     print()
26 
27     net.eval()
28     acc = 0.0
29     with torch.no_grad():
30         for data_test in test_loader:
31             test_images, test_labels = data_test
32             outputs = net(test_images)
33             predict_y = torch.max(outputs, dim=1)[1]#torch.max返回两个数值,一个是最大值,一个是最大值的下标
34             acc += (predict_y == test_labels).sum().item()
35         test_accurate = acc / test_num
36         if test_accurate > best_acc:
37             best_acc = test_accurate
38             torch.save(net.state_dict(), save_path)
39         print([epoch %d] train_loss: %.3f  test_accuracy: %.3f %
40               (epoch + 1, running_loss / step, test_accurate))
41 
42 print(Finished Training)

在完成训练后,训练的权重会保存在所设置路径下的文件中,进行预测的时候,建立模型,载入权重,照一张数字的图片,对其进行裁剪,灰度等操作之后加载入模型进行预测。

 1 from PIL import Image
 2 import  matplotlib.pyplot as plt
 3 from torchvision import transforms
 4 import torch
 5 from model import Net
 6 
 7 img = Image.open("./YLY2@8UMGLW37S$)NCVZ23.png")
 8 
 9 plt.imshow(img)
10 
11 # [N, C, H, W]
12 
13 train_transform = transforms.Compose([
14         transforms.Grayscale(),
15         transforms.Resize((28, 28)),
16         transforms.ToTensor(),
17 ])
18 
19 img = train_transform(img)
20 # expand batch dimension
21 img = torch.unsqueeze(img, dim=0)
22 
23 # create model
24 model = Net()
25 # load model weights
26 model_weight_path = "./mnist.pth"
27 model.load_state_dict(torch.load(model_weight_path))
28 
29 index_to_class = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
30 
31 
32 model.eval()
33 with torch.no_grad():
34     # predict class
35     y = model(img)
36     #print(y.size())
37     output = torch.squeeze(y)
38     #print(output)
39     predict = torch.softmax(output, dim=0)
40     #print(predict)
41     predict_cla = torch.argmax(predict).numpy()
42     #print(predict_cla)
43 print(index_to_class[predict_cla], predict[predict_cla].numpy())
44 plt.show()

需要注意的是,载入模型的图片必须多一个维度batch,所以我们用img = torch.unsqueeze(img, dim=0)在图片的开头增加一个batch维度。

之后载入图片,得到输出,将输出的batch维度压缩掉,使用softmax函数得到概率分布,再用argmax函数得到最大值的下标,打印最大值所对应的类别及其概率。

pytorch基于cnn的手写数字识别(在mnist数据集上训练)(代码片段)

最终成果http://pytorch-cnn-mnist.herokuapp.com/GITHUBhttps://github.com/XavierJiezou/pytorch-cnn-mnist本文以最经典的mnist数据集为例,讲述了使用pytorch做机器学习的一整套流程,文中所提到的所有代码都可以到github中查看。项目场景简单的... 查看详情

pycharm-mnist手写数字的识别

本文是将PyTorch上的MNIST手写数字的识别、CNN网络在PyCharm上进行展示可以看我这篇博客PyTorch-CNNModel关于MNIST数据集无法下载的问题可以看我另一篇博客PyTorch-MNIST数据集无法下载-采用手动下载的方式在PyCharm中导入PyTorc转载请注明... 查看详情

基于mnist数据集实现手写数字识别(代码片段)

介绍在TensorFlow的官方入门课程中,多次用到mnist数据集。mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx*-ubyte.gz的二进制文件。可以直接从官网进... 查看详情

pytorch入门实战|第p1周:实现mnist手写数字识别

查看详情

365计划-1pytorch实现mnist手写数字识别(代码片段)

🍨本文为🔗365天深度学习训练营中的学习记录博客🍦参考文章地址:365天深度学习训练营-第P1周:mnist手写数字识别🍖作者:K同学啊###本项目来自K同学在线指导###importtorch.nnasnnimportmatplotlib.pyplotasplti... 查看详情

pytorch实现手写数字识别(代码片段)

PyTorch实现手写数字识别使用Pytorch实现手写数字识别1.思路和流程分析2.准备训练集和测试集2.1torchvision.transforms的图形数据处理方法2.1.1torchvision.transforms.ToTensor2.1.2torchvision.transforms.Normalize(mean,std)2.1.3torchvision.transforms.Com 查看详情

我用pytorch复现了lenet-5神经网络(mnist手写数据集篇)!

...介绍了卷积神经网络LeNet-5的理论部分。今天我们将使用Pytorch来实现LeNet-5模型,并用它来解决MNIST数据集的识别。正文开始!一、使用LeNet-5网络结构创建MNIST手写数字识别分类器MNIST是一个非常有名的手写体数字识别数据... 查看详情

我用pytorch复现了lenet-5神经网络(mnist手写数据集篇)!

...介绍了卷积神经网络LeNet-5的理论部分。今天我们将使用Pytorch来实现LeNet-5模型,并用它来解决MNIST数据集的识别。正文开始!一、使用LeNet-5网络结构创建MNIST手写数字识别分类器MNIST是一个非常有名的手写体数字识别数据... 查看详情

我用pytorch复现了lenet-5神经网络(mnist手写数据集篇)!

...介绍了卷积神经网络LeNet-5的理论部分。今天我们将使用Pytorch来实现LeNet-5模型,并用它来解决MNIST数据集的识别。正文开始!一、使用LeNet-5网络结构创建MNIST手写数字识别分类器MNIST是一个非常有名的手写体数字识别数据... 查看详情

使用pytorch实现手写数字识别(代码片段)

使用Pytor##标题ch实现手写数字识别思路和流程分析准备数据,这些需要准备DataLoader构建模型,这里可以使用torch构造一个深层的神经网络模型的训练模型的保存,保存模型,后续持续使用模型的评估,使用测... 查看详情

keras实现mnist数据集手写数字识别(代码片段)

一.Tensorflow环境的安装这里我们只讲CPU版本,使用Anaconda进行安装a.首先我们要安装Anaconda链接:https://pan.baidu.com/s/1AxdGi93oN9kXCLdyxOMnRA密码:79ig过程如下:第一步:点击next第二步:IAgree第三步:JustME第四步:自己选择一个恰当位置... 查看详情

pytorch手写数字识别项目增量式训练

dataset.py ‘‘‘准备数据集‘‘‘importtorchfromtorch.utils.dataimportDataLoaderfromtorchvision.datasetsimportMNISTfromtorchvision.transformsimportToTensor,Compose,Normalizeimporttorchvisionimportconfigdefmnist_dataset(train):func=torchvision.transforms.Compose([torchvision.transform... 查看详情

[pytorch系列-34]:卷积神经网络-搭建lenet-5网络与mnist数据集手写数字识别(代码片段)

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121050469目录前言:LeNet网络详解第1章业务领域分析1.1 步骤1-1:... 查看详情

图像分类基于pytorch搭建lstm实现mnist手写数字体识别(单向lstm,附完整代码和数据集)(代码片段)

写在前面:首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌。提起LSTM大家第一反应是在NLP的数据集上比较常见,不过在... 查看详情

java实现bp神经网络mnist手写数字识别(代码片段)

...含服务器和客户端程序,可在服务器训练后使客户端直接使用训练结果,界面有画板,可以手写数字Java实现BP神经网络MNIST手写数字识别如果需要源码,请在下方评论区留下邮箱,我看到就会发过去一、神经网络的构建(1):构建... 查看详情

基于tensorflow2.x实现bp神经网络,实践mnist手写数字识别(代码片段)

...下,如果没有指定数据集的位置,并先前也没有使用过,会自动联网下载该,使该数据集使用起来更加方便,它包括了70000张图 查看详情

基于tensorflow2.x实现bp神经网络,实践mnist手写数字识别(代码片段)

...下,如果没有指定数据集的位置,并先前也没有使用过,会自动联网下载该,使该数据集使用起来更加方便,它包括了70000张图 查看详情

全连接神经网络实现识别手写数据集mnist(代码片段)

全连接神经网络实现识别手写数据集MNISTMNIST是一个由美国由美国邮政系统开发的手写数字识别数据集。手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载。总共4个文件,该文件是二进制内容。train-images-idx3-uby... 查看详情