pytorch迁移学习(代码片段)

ayanwan ayanwan     2022-11-06     797

关键词:

在很多场合中,没有必要从头开始训练整个卷积网络(随机初始化参数),因为没有足够丰富的数据集,而且训练也是非常耗时、耗资源的过程。通常,采用pretrain a ConvNet的方式,然后用ConvNet作为初始化或特征提取器。有两种迁移学习,对应着不同的应用场景。
  • 微调ConvNet:使用已有的model参数代替随机初始化参数进行训练。
  • ConvNet做为特征提取器:我们需要冻结所有的网络权重的更新,最后一层(全连接层)除外。通常,最后一个全连接层是需要根据需求进行修改,并使用一个新的随机权重进行训练。显然,整个网络只有这个层被训练。

pytorch提供了很多pre-trained models,如下:


下面以cifar10为例,cifar10有10类图像 ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')。我们将采用采用第二种方式,修改resnet-18的全连层,以达到cifar10识别目的。

加载数据

print('==> Preparing data..')
transform_train = transforms.Compose([
    #transforms.RandomCrop(224, padding=4),
    transforms.Scale(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Scale(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../data/cifar', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data/cifar', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=2)


加载并修改模型

# ConvNet
model_ft = models.resnet18(pretrained=True)
print(model_ft)

for i, param in enumerate(model_ft.parameters()):
    param.requires_grad = False # 冻结参数的更新

num_ftrs = model_ft.fc.in_features #重新定义fc层,此时,会进行参数的更新。
model_ft.fc = nn.Linear(num_ftrs, 10)
print(model_ft)


训练

def train(epoch):
    model_ft.train()
    for batch_idx, (data, target) in enumerate(trainloader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        optimizer.zero_grad()
        output = model_ft(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch:  [/ (:.0f%)]\\tLoss: :.6f'.format(
                epoch, batch_idx * len(data), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.data[0]))

完整代码可以查看: tfygg/pytorch-tutorials

pytorch迁移学习(代码片段)

在很多场合中,没有必要从头开始训练整个卷积网络(随机初始化参数),因为没有足够丰富的数据集,而且训练也是非常耗时、耗资源的过程。通常,采用pretrainaConvNet的方式,然后用ConvNet作为初始... 查看详情

pytorch迁移学习实战天气识别(代码片段)

代码:pytorch迁移学习实战,天气识别-深度学习文档类资源-CSDN文库dataset importrandomimportnumpyasnpimporttorchfromtorchvisionimporttransformsimporttorchvisionif__name__==\'__main__\':importmatplotlib.pyplotasplt#定义数据处理transformtransform=transforms.Compose([tr... 查看详情

8.1pytorch模型迁移(代码片段)

欢迎订阅本专栏:《PyTorch深度学习实践》订阅地址:https://blog.csdn.net/sinat_33761963/category_9720080.html第二章:认识Tensor的类型、创建、存储、api等,打好Tensor的基础,是进行PyTorch深度学习实践的重中之重的基础... 查看详情

[pytorch系列-49]:卷积神经网络-迁移学习的统一处理流程与软件架构-pytorch代码实现(代码片段)

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121312731目录第1章关于FineTuning与TransferTrainning概述1.1理论基础1.2迁移学习... 查看详情

pytorch基础——迁移学习(代码片段)

一、介绍内容使机器能够“举一反三”的能力知识点使用PyTorch的数据集套件从本地加载数据的方法迁移训练好的大型神经网络模型到自己模型中的方法迁移学习与普通深度学习方法的效果区别两种迁移学习方法的区别二、从图... 查看详情

pytorch实例3——迁移学习(代码片段)

...模式4.3.2.2固定值模式4.4结论1.实验环境JupyterNotebookPython3.7PyTorch1.4.02.实验目的迁移学习,让机器拥有能够“举一反三”的能力。本次实验就以“是蚂蚁还是蜜蜂”为例,探索如何将已训练好的大网络迁移到小数据集上࿰... 查看详情

小白学习pytorch教程十二迁移学习:微调vgg19实现图像分类(代码片段)

@Author:Runsen前言:迁移学习就是利用数据、任务或模型之间的相似性,将在旧的领域学习过或训练好的模型,应用于新的领域这样的一个过程。从这段定义里面,我们可以窥见迁移学习的关键点所在,... 查看详情

8.1pytorch模型迁移(代码片段)

欢迎订阅本专栏:《PyTorch深度学习实践》订阅地址:https://blog.csdn.net/sinat_33761963/category_9720080.html第二章:认识Tensor的类型、创建、存储、api等,打好Tensor的基础,是进行PyTorch深度学习实践的重中之重的基础... 查看详情

小白学习pytorch教程十四迁移学习:微调resnet实现男人和女人图像分类(代码片段)

@Author:Runsen上次微调了Alexnet,这次微调ResNet实现男人和女人图像分类。ResNet是ResidualNetworks的缩写,是一种经典的神经网络,用作许多计算机视觉任务。ResNet论文参见此处:https://arxiv.org/abs/1512.03385该模型... 查看详情

小白学习pytorch教程十三迁移学习:微调alexnet实现ant和bee图像分类(代码片段)

@Author:Runsen上次微调了VGG19,这次微调Alexnet实现ant和bee图像分类。多年来,CNN许多变体已经发展起来,从而产生了几种CNN架构。其中最常见的是:LeNet-5(1998)AlexNet(2012)ZFNet(2013)GoogleNet/Inception(2014 查看详情

pytorch迁移学习教程(计算机视觉应用实例)(代码片段)

文章目录迁移学习什么是迁移学习为何用迁移学习迁移学习的优点迁移学习的方法迁移方法的选择学习目标下载数据导入模块数据增强制作数据集数据加载器相关信息的打印训练数据可视化训练模型参数微调的方法特征提取的方... 查看详情

pytorch测试迁移学习

 训练源码:源码仓库:https://github.com/pytorch/tutorials迁移学习测试代码:tutorials/beginner_source/transfer_learning_tutorial.py 准备工作:下载数数据集:https://download.pytorch.org/tutorial/hymenoptera_data.zip &n 查看详情

pytorch迁移学习(代码片段)

...权重进行训练。显然,整个网络只有这个层被训练。pytorch提供了很多pre-trainedmodels,如下:下面以cifar10为例,cifar10有10类图像('plane','car','bird','cat','deer','dog','frog','horse... 查看详情

动手学pytorch-迁移学习(代码片段)

迁移学习1.基本概念2.Fine-tuning3.Fixed1.基本概念假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1,000张不同角度的图像,然后在收集到的图像数... 查看详情

pytorch实战用pytorch实现基于神经网络的图像风格迁移(代码片段)

用PyTorch实现基于神经网络的图像风格迁移1.风格迁移原理介绍2.FastNeuralStyle网络结构3.用PyTorch实现风格迁移3.1首先看看如何使用预训练的VGG。3.2接下来要实现风格迁移网络参考资料风格迁移,又称为风格转换。只需要给定原... 查看详情

基于深度学习的语义分割初探fcn以及pytorch代码实现(代码片段)

基于深度学习的语义分割初探FCN以及pytorch代码实现FCN论文论文地址:https://arxiv.org/abs/1411.4038FCN是基于深度学习方法的第一篇关于语义分割的开山之作,虽然这篇文章的分割结果现在看起来并不是目前最好的,但其意... 查看详情

如何使用 Pytorch 将二进制迁移学习模型扩展到多个图像类别?

】如何使用Pytorch将二进制迁移学习模型扩展到多个图像类别?【英文标题】:HowtoextendabinarytransferlearningmodelwithPytorchtomultipleimagecategories?【发布时间】:2021-07-0813:46:43【问题描述】:我正在处理一些代码,这些代码使用ResNet-18模... 查看详情

4.6pytorch使用损失函数(代码片段)

欢迎订阅本专栏:《PyTorch深度学习实践》订阅地址:https://blog.csdn.net/sinat_33761963/category_9720080.html第二章:认识Tensor的类型、创建、存储、api等,打好Tensor的基础,是进行PyTorch深度学习实践的重中之重的基础... 查看详情