(机器学习深度学习常用库框架|pytorch篇)第三节:pytorch之torchvision详解(代码片段)

快乐江湖 快乐江湖     2023-01-04     800

关键词:

文章目录

一:torchvision概述

torchvisiontorchvision是Pytorch的一个图形库,主要用来构建计算机视觉模型,torchvision由以下四个部分构成

  • torchvision.datasets:包括一些加载数据的函数和常用的数据集接口
  • torchvision.models:包含常用的模型结构(含预训练模型),例如AlexNet、ResNet等等
  • torchvision.transforms:包含一些常见的图片变换,例如裁剪、旋转等等
  • torchvision.utils:其他用法

二:torchvision.datasets

torchvision.datasets:该模块下既有官方提供的数据集,也有自定义数据集的类,两者都是torch.utils.data.Dataset的子类,因此可以直接输入到torch.utils.data.DataLoader中去

(1)官方数据集

torchvision.datasets中提供的官方数据如下,这些数据集详细介绍见此文:数据集介绍

MNIST
Fashion-MNIST
KMNIST
EMNIST
FakeData
COCO
Captions
Detection
LSUN
​ImageFolder
DatasetFolder
Imagenet-12
CIFAR
STL10
SVHN
PhotoTour
SBU
Flickr
VOC
Cityscapes
...

这里我们以MNIST数据集为例,演示一下这些官方数据集如何加载,其余数据集的加载和MNIST一致

如下,使用torchvision.datasets.MNIST加载MNIST数据集

train_data = dataset.MNIST(root='./mnist/',
                           train=True,
                           transform=transforms.ToTensor(),
                           download=True)
test_data = dataset.MNIST(root='./mnist/',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=False)
  • root:表示数据集待存放的目录
  • train:如果为true将会使用训练集的数据集(training.pt),如果为false将会使用测试集数据集(test.pt
  • download:如果为true将会从网络上下载并放入root中,如果数据集已下载则不会再次下载
  • transform:接受PIL图片并返回转换后的图片,常用的就是转换为tensor(这里便会调用torchvision.transform

数据集加载成功后,文件布局如下

(2)自定义数据集类

这里的自定义数据集类指的主要是torchvision.datasets.ImageFolder(),它继承自 torchvision.datasets.DatasetFolder(),后者又继承自 torchvision.datasets.VisionDataset(),而VisionDataset 则是 torch.utils.data.Dataset 的子类

torchvision.datasets.CIFAR数据集为例说明如何使用torchvision.datasets.ImageFolder(),这里的torchvision.datasets.CIFAR我已经将其转换为png格式存储

  • 下载链接
  • CIFAR10有60000张图片,共分为10个类别,其中50000张为训练图片(每个类别5000张),10000张为测试图片(每个类别1000张)

图片文件布局如下,torchvision.datasets.ImageFolder()要求你的图片数据必须按照以下方式进行组织

torchvision.datasets.ImageFoler参数说明

  • root:图片存储的根目录,即各类别文件夹所在目录的上一级目录
  • transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片
  • target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
  • loader:表示数据集加载方式,通常默认加载方式即可
torchvision.datasets.ImageFolder(root,transform,target_transform,loader)

如下,使用torchvision.datasets.ImageFoler对前面的图片进行加载

  • 注意transforms部分可暂时忽略
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor()
])

train_dataset = torchvision.datasets.ImageFolder(root='./data/train', transform=train_transforms)
test_dataset = torchvision.datasets.ImageFolder(root='./data/test', transform=transforms.ToTensor)
print(len(train_dataset))
print(len(test_dataset))

同时,通过torchvision.datasets.ImageFolder生成的train_datasettest_dataset还有如下3个成员变量

  • self.classes:使用一个list保存类别名称
  • self.class_to_idx:类别对应的索引
  • self.imgs:是一个list,每个元素是一个tuple,每个tuple保存的是(img-path, class)
print(train_dataset.classes[: 5])
print("-"*30)
print(train_dataset.class_to_idx)
print("-"*30)
print(train_dataset.imgs[: 5])

(3)ImageFolder手动实现

仍然以上述CIFAR10数据集为例,我们手动实现一下ImageFolder,这对你理解它大有帮助

import torchvision.datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import glob


# 类别名字
label_name = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck"
]
# 类比名字映射索引
label_dict = 
for idx, name in enumerate(label_name):
    label_dict[name] = idx


def default_loader(path):
    return Image.open(path).convert("RGB")

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor()
])

class MyDataset(Dataset):
    """
        im_list:是一个列表,每一个元素是图片路径
        transform:对图片进行增强
        loader:使用PIL对图片进行加载
    """
    def __init__(self, im_list, transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        # imgs为二维列表,每一个子列表中第一个元素存储im_list,第二个通过label_dict映射为索引
        imgs = []

        for im_item in im_list:
            # 路径'./data/test/airplane/aeroplane_s_000002.png'中倒数第二个是标签名
            im_label_name = im_item.split("\\\\")[-2]
            imgs.append([im_item, label_dict[im_label_name]])

        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        im__path, im_label = self.imgs[index]

        # 会调用PIL加载图片数据
        im_data = self.loader(im__path)
        # 如果给了transoform那么就对图片进行增强
        if self.transform is not None:
            im_data = self.transform(im_data)

        return im_data, im_label

    def __len__(self):
        return len(self.imgs)


if __name__ == '__main__':
    im_train_list = glob.glob(r'./data/train/*/*.png')
    im_test_list = glob.glob(r'./data/test/*/*.png')

    train_dataset = MyDataset(im_train_list, transform=train_transforms)
    test_dataset = MyDataset(im_test_list, transform=transforms.ToTensor())
    print(len(train_dataset))
    print(len(test_dataset))

    train_loader = DataLoader(dataset=train_dataset, batch_size=6, shuffle=True, num_workers=0)
    test_loader = DataLoader(dataset=test_dataset, batch_size=6, shuffle=False, num_workers=0)

三:torchvision.transforms

torchvision.transforms:该模块是Pytorch中的图像预处理包,包含了一些常用的图像变换,主要实现对数据集的预处理、数据增强,转化为tensor等操作

使用时如果有很多变换,那么一般会使用Compose将这些步骤给整合到一起

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor()
])

train_dataset = torchvision.datasets.ImageFolder(root='./data/train', transform=train_transforms

如果变换时只有一种,那么一般会直接给到形参,比如最常使用到的ToTensor

test_dataset = torchvision.datasets.ImageFolder(root='./data/test', transform=transforms.ToTensor)

torchvision.transforms涉及变换主要有以下4类

  • 裁剪

    • 中心裁剪transforms.CenterCrop
    • 随机裁剪transforms.RandomCrop
    • 随机长宽比裁剪transforms.RandomResizedCrop
    • 上下左右中心裁剪transforms.FiveCrop
    • 上下左右中心裁剪后翻转transforms.TenCrop
  • 翻转和旋转

    • 依概率p水平翻转transforms.RandomHorizontalFlip(p=0.5)
    • 依概率p垂直翻转transforms.RandomVerticalFlip(p=0.5)
    • 随机旋转transforms.RandomRotation
  • 图像变换和转换

    • 变换为某一尺寸transforms.Resize
    • 标准化transforms.Normalize
    • 转化为tensor并归一化transforms.ToTensor
    • 填充transforms.Pad
    • 修改亮度、对比度和饱和度transforms.ColorJitter
    • 转化为灰度图transforms.Grayscale
    • 线性变化transforms.LinearTransformation
    • 仿射变换transforms.RandomAffine
    • 依概率p转化为灰度图transforms.RandomGrayscale
    • 将数据转化为PILImagetransforms.ToPILImage
  • 其他操作

    • transforms操作使数据增强更灵活transforms.RandomChoice(transforms)
    • 从给定的一系列transforms中选定一个操作transforms.RandomApply(transforms, p=0.5)
    • 给一个transform加上概率进行操作transforms.RandomOrder

四:torchvision.models

torchvision.models:该模块提供了很多图像处理中的常用模型,并且提供了与训练版本,主要有

  • AlexNet: AlexNet variant from the “One weird trick” paper.
  • VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
  • ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
  • SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1

(机器学习深度学习常用库框架|pytorch篇)第一节:pytorch简介和其核心概念(代码片段)

文章目录一:什么是Pytorch二:Pytorch优势三:Pytorch三大核心概念(1)tensor(张量)(2)autograd(自动微分-变量)(3)nn.Module(神经网络)四:tensor和机器学习、深度... 查看详情

(机器学习深度学习常用库框架|pytorch篇)第三节:pytorch之torchvision详解(代码片段)

...hvision.models一:torchvision概述torchvision:torchvision是Pytorch的一个图形库,主要用来构建计算机视觉模型,torchvision由以下四个部分构成torchvision.datasets:包括一些加载数据的函数和常用的数据集接口torchvision.models&... 查看详情

深度学习-pytorch框架实战系列

深度学习-PyTorch框架实战系列PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能... 查看详情

1.pytorch是什么?(代码片段)

这篇博客将介绍PyTorch深度学习库,包括:PyTorch是什么如何安装PyTorch重要的PyTorch功能,包括张量和自动标记PyTorch如何支持GPU为什么PyTorch在研究人员中如此受欢迎PyTorch是否优于Keras/TensorFlow是否应该在项目中使用PyTorch... 查看详情

3.使用pytorch深度学习库训练第一个卷积神经网络cnn(代码片段)

这篇博客将介绍如何使用PyTorch深度学习库训练第一个卷积神经网络(CNN)。训练CNN使用KMNIST数据集(MNISTdigits数据集的替代品,内置在PyTorch中)识别手写平假名字符(handwrittenHiraganacharacters)。在图像... 查看详情

ai常用框架和工具丨12.深度学习框架pytorch

深度学习框架PyTorch,AI常用框架和工具之一。理论知识结合代码实例,希望对您有所帮助。文章目录环境说明PyTorch安装一、PyTorch简介1.1Torch1.2从Torch到PyTorch1.3PyTorch二、PyTorch中张量操作2.1torch2.2张量的常见操作2.3张量的其他操作... 查看详情

ai人工智能机器学习深度学习学习路径及推荐书籍

要学习Pytorch,需要掌握以下基本知识:编程语言:Pytorch使用Python作为主要编程语言,因此需要熟悉Python编程语言。线性代数和微积分:Pytorch主要用于深度学习领域,深度学习是基于线性代数和微积分的,因此需要具备线性代数... 查看详情

嵌入式学深度学习:1pytorch框架搭建(代码片段)

嵌入式学深度学习:1、Pytorch框架搭建1、介绍2、Pytorch开发环境搭建2.1、查看GPU是否支持CUDA2.2、安装Miniconda2.3、使用Conda安装pytorch2.4、安装常用库3、简单使用验证1、介绍深度学习是机器学习的一种,如下图:目前深... 查看详情

《深度学习与计算机视觉算法原理框架应用》pdf+《深度学习之pytorch实战计算机视觉》pdf

...下载:https://pan.baidu.com/s/1P0-o29x0ZrXp8WotN7GzcA《深度学习之PyTorch实战计算机视觉》更多分享:https://pan.baidu.com/s/1g4hv05UZ_w92uh9NNNkCaA《深度学习与计算机视觉算法原理框架应用》共13章,分为2篇。第1篇基础知识,介绍了人工智能发展... 查看详情

pytorchpytorch基础第0章(代码片段)

本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052这是目录PyTorch的简介PyTorch构建深度学习模型的步骤搭建pytorch使用环境PyTorch的简介PyTorch是一个开源的机器学习框架,由Facebook的人工智能研究院(F... 查看详情

从tensorflow到pytorch:九大深度学习框架哪款最适合你?

开源的深度学习神经网络正步入成熟,而现在有许多框架具备为个性化方案提供先进的机器学习和人工智能的能力。那么如何决定哪个开源框架最适合你呢?本文试图通过对比深度学习各大框架的优缺点,从而为各位读者提供一... 查看详情

tensorfloworpytorch

...比较突出,他们是两个最流行的深度学习库:TensorFlow和PyTorch。你没有办法指出这两个库有什么本质的不同,不用担心!我将在这网络上无休止的存储空间中添加一篇新的文章,也许可以帮你弄清楚一些问题。我将简要的快速的... 查看详情

csdn独家|全网首发|pytorch深度学习·理论篇(2023版)目录

很高兴和大家在这里分享我的最新专栏 Pytorch深度学习·理论篇(2023版),恭喜本博客浏览量达到两百万,CSDN内容合伙人,CSDN人工智能领域实力新星~0Pytorch深度学习·理论篇+实战篇(2023版)大纲1Pytorch深度学习·理论篇... 查看详情

人工智能学习

...二阶段:编程python工具库实战/python网络爬虫第三阶段:机器学习机器学习入门/机器学习提升第四阶段:数据挖掘实战数据挖掘入门/数据分析实战第五阶段:深度学习深度学习网络与框架/深度学习项目实战 https://blog.cs... 查看详情

pytorch是啥?

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:具有强大的GPU加速的张量计算(如NumPy... 查看详情

深度学习-机器视觉学习路线

...ASK-FasterRCNN2、平台Tensorflow\Caffe\Pytorch3、分析工具python及相应的依赖库 numpy pandas matplotlib scipy4、前沿知识关注GAN、迁移学习等5、开源数据集应... 查看详情

想学深度学习开发,需要提前掌握哪些python知识?

...是大数据竞赛的热门),下面转战深度学习:现在推荐用pytorch,今年的顶会pytorch占了半壁江山,这个比tensorflow,keras简单易懂,而且功能强大,pytorch的库叫做torch,和torchvision同时服用效果更加(还可以加torchnet等一些其他模块... 查看详情

人工智能领域常用的开源框架和库(含机器学习/深度学习/强化学习/知识图谱/图神经网络)

...hon编程语言的开源框架和库,因此全面性肯定有限!一、机器学习常用的开源框架和库1.Scikit-learn作为专门面向机器学习的Python开源框架,Scikit-learn内部实现了多种机器学习算法,容易安装和使用 查看详情