[基于pytorch的mnist识别02]用户数据集的读取(代码片段)

AIplusX AIplusX     2023-01-10     627

关键词:

写在前面

pytorch包含了很多包括mnist在内的开源数据集,但是如果要建立自己的神经网络的话肯定需要训练自己的数据集,那么如何利用pytorch加载用户自己的数据集呢?今天就来解决这个问题。

今天的工作

需要加载用户自己的数据需要继承pytorch的Dataset类,并且重载其中的__getitem__方法和__len__方法,前者是告诉pytorch如何根据索引来获取对应的样例,后者是告诉pytorch数据集的大小,在这里先放一下我的继承实现:

class UserMNIST(Dataset):
    def __init__(self, root='', train=True, transform=None, target_transform=None):
        super(UserMNIST, self).__init__()
        self.train = train #type of datasets
        self.transform = transform
        self.target_transform = target_transform
        self.train_nums = int(6e4)
        self.test_ratio = int(9e-1)
        self.validate_nums = int(1e4)
#         print(self.train_nums)#The scientific counting method is float,which should be changed by user
        
        #load files path
        if self.train :
            self.imgs_folder_path = train_imgs_path
            self.labels_folder_path = train_labels_path
            self.img_nums = self.train_nums
        else:
            self.imgs_folder_path = validate_imgs_path
            self.labels_folder_path = validate_labels_path
            self.img_nums = self.validate_nums
        
        #load dataset
        with open(self.imgs_folder_path, 'rb') as _imgs:
            self._train_images = _imgs.read()
        with open(self.labels_folder_path, 'rb') as _labs:
            self._train_labels = _labs.read()
            
            
    def __getitem__(self, index):
        image = self.getImages(self._train_images, index)
        label = self.getLabels(self._train_labels, index)
        return image,label
    
    def __len__(self):
        return self.img_nums
    
    def getImages(self, image, index):
        img_size_bit = struct.calcsize('>784B')
        start_index = struct.calcsize('>IIII') + index * img_size_bit
        temp = struct.unpack_from('>784B', image, start_index)
        img = np.reshape(temp, (28, 28))
#         img = self.normalization(np.array(temp, dtype=float))
#         print(img)
        return img

    def getLabels(self, label, index):
        lab_size_bit = struct.calcsize('>1B')
        start_index = struct.calcsize('>II') + index * lab_size_bit
        lab = struct.unpack_from('>1B', label, start_index)
        return lab
    
    def normalization(self, x):
        max = float(255)
        min = float(0)
        for i in range(0, 784):
            x[i] = (x[i] - min) / (max - min)
        return x

usermnist_train = UserMNIST(train=True)#how to one data,return numpy(user define) type
usermnist_train_loader = DataLoader(dataset=usermnist_train, batch_size=user_batch_size, shuffle=False)#do somethings to get all data,return tensor type

对我这个类进行测试,测试程序如下所示(输出结果附后):

dataiter = iter(usermnist_train_loader)
images,labels = dataiter.next()
print(images.shape, labels)
print(type(images), type(labels))
plt.imshow(images[0])
plt.show()

img, lab = usermnist_train.__getitem__(6) # get the 34th sample
print(type(img))
print(type(lab))
plt.imshow(img)
plt.show()

其实关于这个类的重载最重要的是其使用方法,也就是这2个语句:

usermnist_train = UserMNIST(train=True)#how to one data,return numpy(user define) type
usermnist_train_loader = DataLoader(dataset=usermnist_train, batch_size=user_batch_size, shuffle=False)#do somethings to get all data,return tensor type

UserMNIST是继承了Dataset类的实现,从测试样例之中就可以看出来正确继承了,其实例化之后的usermnist_train做为参数通过DataLoader进行数据加载。

我的理解是DataLoader类是一种数据集加载优化的方法,因为如果用户有大量的数据需要加载的话会很花时间,DataLoader类可以通过多线程加载进行提速,而且可以打乱数据集的顺序增加随机性,使得训练集环境更接近实际环境等优点。

[基于pytorch的mnist识别03]运行模型(代码片段)

写在前面今天调通了pytorch模型,同时进行了简单的模型测试知识点总结1:python忽略某些特定语句的warning:importwarningswithwarnings.catch_warnings():#ignoresomewarningswarnings.simplefilter("ignore")loss=criteri 查看详情

图像分类基于pytorch搭建lstm实现mnist手写数字体识别(双向lstm,附完整代码和数据集)(代码片段)

...sdn.net/AugustMe/article/details/128969138文章中,我们使用了基于PyTorch搭建LSTM实现MNIST手写数字体识别,LSTM是单向的,现在我们使用双向LSTM试一试效果,和之前的单向LSTM模型稍微有差别,请注意查看代码的变化。1.导入依赖库这些依赖... 查看详情

[基于pytorch的mnist识别05]总结(代码片段)

写在前面走走停停,我自己做的第一个机器学习的项目:MNIST手写字符集的识别,终于结束了,一路走来也算是踩了大大小小的坑,在这篇文章里做一个总结。模型配置模型总共有4层,输入层有784个神经元... 查看详情

[基于pytorch的mnist识别04]模型调试(代码片段)

知识点1:nn.MSELoss(x,y)均方根值函数如果输入的x和y参数的维度不同,维度低的tensor会进行扩张填充,扩充形式如下图所示:因此之后在训练模型的时候要保持输入变量参数一致,养成良好的编程习惯。2&#x... 查看详情

[基于pytorch的mnist识别01]神经网络建立(代码片段)

写在前面前面我曾尝试在无框架的情况下进行神经网络的构建和调参,我发现虽然网络构建起来和运行起来都问题不大,但是在调参时就会显现无框架的弊端。经过初步的调参之后,我建立的网络识别准确率只能达到... 查看详情

pytorch使用mnist数据集实现手写数字识别(代码片段)

...用mnist数据集实现手写数字识别是入门必做吧。这里使用pyTorch框架进行简单神经网络的搭建。首先导入需要的包。1importtorch2importtorch.nnasnn3importtorch.utils.dataasData4importtorchvision 接下来需要下载mnist数据集。我们创建train_data。使... 查看详情

用pytorch构建基于卷积神经网络的手写数字识别模型(代码片段)

本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052目录一、MINST数据集介绍与分析二、卷积神经网络三、基于卷积神经网络的手写数字识别一、MINST数据集介绍与分析        MINST数据库是机器学习领域... 查看详情

pycharm-mnist手写数字的识别

本文是将PyTorch上的MNIST手写数字的识别、CNN网络在PyCharm上进行展示可以看我这篇博客PyTorch-CNNModel关于MNIST数据集无法下载的问题可以看我另一篇博客PyTorch-MNIST数据集无法下载-采用手动下载的方式在PyCharm中导入PyTorc转载请注明... 查看详情

基于pytorch平台实现对mnist数据集的分类分析(前馈神经网络softmax)基础版(代码片段)

基于pytorch平台实现对MNIST数据集的分类分析(前馈神经网络、softmax)基础版文章目录基于pytorch平台实现对MNIST数据集的分类分析(前馈神经网络、softmax)基础版前言一、基于“前馈神经网络”模型,分类分析... 查看详情

图像分类基于pytorch搭建lstm实现mnist手写数字体识别(单向lstm,附完整代码和数据集)(代码片段)

写在前面:首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌。提起LSTM大家第一反应是在NLP的数据集上比较常见,不过在... 查看详情

基于mnist数据集实现手写数字识别(代码片段)

介绍在TensorFlow的官方入门课程中,多次用到mnist数据集。mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx*-ubyte.gz的二进制文件。可以直接从官网进... 查看详情

[pytorch系列-34]:卷积神经网络-搭建lenet-5网络与mnist数据集手写数字识别(代码片段)

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121050469目录前言:LeNet网络详解第1章业务领域分析1.1 步骤1-1:... 查看详情

[pytorch系列-40]:卷积神经网络-模型的恢复/加载-搭建lenet-5网络与mnist数据集手写数字识别(代码片段)

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121132377目录第1章模型的恢复与加载1.1概述 1.2 模型的恢复与加载类型1.... 查看详情

pytorch入门实战|第p1周:实现mnist手写数字识别

查看详情

pytorch学习(代码片段)

Pytorch1_手写字识别数据集的载入数据集展示网络搭建模型训练训练损失函数图模型预测下面将构造一个三层全连接层的神经网络来对手写数字进行识别。数据集的载入importtorchfromtorchimportnnfromtorch.nnimportfunctionalasFfromtorchimportoptimim... 查看详情

pytorch手写数字识别项目增量式训练

dataset.py ‘‘‘准备数据集‘‘‘importtorchfromtorch.utils.dataimportDataLoaderfromtorchvision.datasetsimportMNISTfromtorchvision.transformsimportToTensor,Compose,Normalizeimporttorchvisionimportconfigdefmnist_dataset(train):func=torchvision.transforms.Compose([torchvision.transform... 查看详情

我用pytorch复现了lenet-5神经网络(mnist手写数据集篇)!

...介绍了卷积神经网络LeNet-5的理论部分。今天我们将使用Pytorch来实现LeNet-5模型,并用它来解决MNIST数据集的识别。正文开始!一、使用LeNet-5网络结构创建MNIST手写数字识别分类器MNIST是一个非常有名的手写体数字识别数据... 查看详情

我用pytorch复现了lenet-5神经网络(mnist手写数据集篇)!

...介绍了卷积神经网络LeNet-5的理论部分。今天我们将使用Pytorch来实现LeNet-5模型,并用它来解决MNIST数据集的识别。正文开始!一、使用LeNet-5网络结构创建MNIST手写数字识别分类器MNIST是一个非常有名的手写体数字识别数据... 查看详情