pytorch土堆pytorch教程学习torchvision中的数据集的使用(代码片段)

hzyuan hzyuan     2023-05-03     597

关键词:

torchvision 中的数据集使用

torchvision.datasets模块中提供了许多内置的数据集。

内置的数据集有 CIFAR10、MNIST、COCO等,更多可进入 pytorch 官网查看。

所有内置的数据集都继承了 torch.utils.data.Dataset 类,并且实现了 __getitem____len__

所有的数据集几乎都有相似的API。下面以 CIFAR10 数据集的使用为例来认识下内置数据集的用法。

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

\'\'\'
dataset = torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
Args:
root(string):数据集存放的根目录。
train(bool):如果True则从训练集创建数据集,False则从测试集创建数据集。
transform(callable):需要对图像进行的转换操作
target_transforms(callable):需要对 target 进行的转换操作
download(bool):True则从网络下载数据集到根目录。如果数据集已经存在,则不再下载。
\'\'\'
# 创建训练集
train_set = torchvision.datasets.CIFAR10(root=\'./dataset\', train=True, transform=dataset_transform, download=True) 
# 创建测试集
test_set = torchvision.datasets.CIFAR10(root=\'./dataset\', train=False, transform=dataset_transform, download=True)

img, target = test_set[0]  # 取出图像和target
print(img, test_set.classes[target])

# 在 tensorboard 里打开十张图像
writer = SummaryWriter(\'logs\')
for i in range(10):
    img, target = test_set[i]
    writer.add_image(\'test_set\', img, i)
writer.close()

内置数据集很方便地供我们下载使用。根据源码或者官方文档可以了解到创建数据集所需传入的参数,然后需要关注__getitem__ 方法返回的结果是什么

自定义数据集

自己定义的数据集可以参照内置数据集,即继承 torch.utils.data.Dataset 类,并且重写 __getitem____len__
数据存放在 dataset/train里,分为两个目录 antsbees,也分别是数据的标签,如下图所示:

from PIL import Image
from torch.utils.data import Dataset
import os

class MyDataSet(Dataset):

    # 在__init__里加载
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir  # 根目录
        self.label_dir = label_dir  # 标签目录
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)  # 图片路径列表

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path) # 获取数据 
        label = self.label_dir # 获取label
        return img, label

    def __len__(self):
        return len(self.img_path) # 获取数据集长度

# test
root_dir = \'dataset/train\'
ants_label_dir = \'ants\'
bees_label_dir = \'bees\'
# 生成两个数据集
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset  # 拼接两个数据集

img1, label1 = ants_dataset[0]
img1.show()
print(\'label1:\', label1)
img2, label2 = train_dataset[130]
img2.show()
print(\'label2:\', label2)

pytorch土堆pytorch教程学习torchvision中的数据集的使用(代码片段)

...据集。内置的数据集有CIFAR10、MNIST、COCO等,更多可进入pytorch官网查看。所有内置的数据集都继承了torch.utils.data.Dataset类,并且实现了__getitem__和__len__。所有的数据集几乎都有相似的API。下面以CIFAR10数据集的使用为例来认识下内... 查看详情

我是土堆-pytorch教程知识点学习总结笔记(代码片段)

此文章为【我是土堆-Pytorch教程】知识点学习总结笔记(四)包括:神经网络-非线性激活、神经网络-线性层及其他层介绍、神经网络-搭建小实战和Sequential的使用、损失函数与反向传播、优化器、现有网络模型的使用... 查看详情

pytorch土堆pytorch教程学习transforms的使用(代码片段)

transforms在工具包torchvision下,用来对图像进行预处理:数据中心化、数据标准化、缩放、裁剪、旋转、翻转、填充、噪声添加、灰度变换、线性变换、仿射变换、亮度/饱和度/对比度变换等。transforms本质就是一个python文件,相当... 查看详情

我是土堆-pytorch教程知识点学习总结笔记(代码片段)

此文章为【我是土堆- Pytorch教程】知识点学习总结笔记(二)包括:TensorBoard的使用(一)、TensorBoard的使用(二)、Transforms的使用(一)、Transforms的使用(二)、 常见的Transforms(... 查看详情

我是土堆-pytorch教程知识点学习总结笔记(代码片段)

此文章【我是土堆-Pytorch教程】知识点学习总结笔记(一)包括:python学习中的2大法宝、Pycharm及Jupyter使用及对比、PyTorch加载数据初认识、Dataset类代码实战。目录1.python学习中的2大法宝实战操作总结2.Pycharm及Jupyter使... 查看详情

pytorch学习总结

torch.zeroshttps://pytorch.org/docs/stable/generated/torch.zeros.htmltorch.repeathttps://pytorch.org/docs/stable/generated/torch.Tensor.repeat.htmltorch.unsqueezehttps://pytorch.org/docs/stable/genera 查看详情

深度学习pytorch和torch?

为什么我们称呼为pytorch,而在编写时,写作importtorch而不是importpytorch?参考技术A项目名叫Pytorch,模块名叫torch,这个问题没有为什么,开发者就是这么定义的。就跟OpenCV一样项目名叫OpenCV,模块名叫cv2。项目名叫Pillow,模块名叫... 查看详情

pytorch学习问题汇总

问题1:torch各个类型数据格式如何转换?数据类型在官方文档torch.Tensor中,有八种类型。#尝试一i32=torch.IntTensor([1,2,3])i64=torch.LongTensor([1,2,3])#两种转换都报错#new_i64=torch.IntTensor(i64)#new_i32=torch.LongTensor(i32)#didn‘tmatchbecauseso 查看详情

pytorch深度学习实战-2-pytorch和numpy互相转换

importtorchimportnumpy#创建tensortorch_data=torch.Tensor([1,-2,3])#Torch-->arraynp_data=torch_data.numpy()#array-->Torchtorch_data=torch.from_numpy(np_data) 查看详情

pytorch深度学习实战-2-pytorch和numpy互相转换

importtorchimportnumpy#创建tensortorch_data=torch.Tensor([1,-2,3])#Torch-->arraynp_data=torch_data.numpy()#array-->Torchtorch_data=torch.from_numpy(np_data) 查看详情

ai常用框架和工具丨12.深度学习框架pytorch

深度学习框架PyTorch,AI常用框架和工具之一。理论知识结合代码实例,希望对您有所帮助。文章目录环境说明PyTorch安装一、PyTorch简介1.1Torch1.2从Torch到PyTorch1.3PyTorch二、PyTorch中张量操作2.1torch2.2张量的常见操作2.3张量的其他操作... 查看详情

torch教程[3]使用pytorch自带的反向传播

#-*-coding:utf-8-*-importtorchfromtorch.autogradimportVariabledtype=torch.FloatTensor#dtype=torch.cuda.FloatTensor#UncommentthistorunonGPU#Nisbatchsize;D_inisinputdimension;#Hishiddendimension;D_outis 查看详情

pytorch60分钟入门教程:pytorch深度学习官方入门中文教程(代码片段)

什么是 PyTorch?PyTorch是一个基于Python的科学计算包,主要定位两类人群:NumPy的替代品,可以利用GPU的性能进行计算。深度学习研究平台拥有足够的灵活性和速度开始学习Tensors(张量)Tensors类似于NumPy的ndarrays,同时 Tensors可... 查看详情

window10系统如何安装pytorch教程

...s系统电脑安装深度学习教程对于深度学习来说,安装pytorch和tensorflow是基础中的基础,毕竟无法搭建环境就无法进行训练程序,我也是首先对window安装好pytorch和tenso 查看详情

pytorch是啥?

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:具有强大的GPU加速的张量计算(如NumPy... 查看详情

小白学习pytorch教程十七基于torch实现unet图像分割模型(代码片段)

@Author:Runsen在图像领域,除了分类,CNN今天还用于更高级的问题,如图像分割、对象检测等。图像分割是计算机视觉中的一个过程,其中图像被分割成代表图像中每个不同类别的不同段。上面图片一段代表... 查看详情

pytorch深度学习60分钟闪电战(代码片段)

https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html官方推荐的一篇教程Tensors#Constructa5x3matrix,uninitialized:x=torch.empty(5,3)#Constructarandomlyinitializedmatrix:x=torch.rand(5,3)#Constru 查看详情

pytorch学习(代码片段)

importtorchimportnumpyasnpnp_data=np.arange(6).reshape(2,3)torch_data=torch.from_numpy(np_data)tensor2array=torch_data.numpy()data=[-1,-1,2,-2]tensor=torch.FloatTensor(data)data=[[1,2],[3,4]]tensor=torch.FloatTensor(data)print(‘numpy‘,np_data,‘torch‘,torch_data,‘tensor2array‘,tensor2arra... 查看详情