pytorch两种不同分类层的设计方法(代码片段)

算法与编程之美 算法与编程之美     2022-10-22     537

关键词:

问题

涉及到图像分类的网络的最后一层分类层,有两种实现方法,如下所示,你更偏向于哪种方法呢?

方法

方法1

import torch
from torch import nn


'''
测试池化和卷积组合的分类层
'''
class MyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        
        self.conv = nn.Conv2d(3, 32, 3, padding=1)
        
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(32, 2)
        
    
    def forward(self, x):
        x = self.conv(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1) # 展开所有元素
        out = self.classifier(x)
        
        return out
    
if __name__ == '__main__':

    from torchsummary import summary
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
    x = torch.rand(size=(1, 3, 7, 7)).to(device)
    net = MyNet().to(device)
    
    summary(net, (3, 7, 7))
    


方法2

import torch
from torch import nn


'''
测试池化和卷积组合的分类层
'''
class MyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        
        self.conv = nn.Conv2d(3, 32, 3, padding=1)
        
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, 2)
        )
        
    
    def forward(self, x):
        x = self.conv(x)
        out = self.classifier(x)
        
        return out
    
if __name__ == '__main__':

    from torchsummary import summary
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
    x = torch.rand(size=(1, 3, 7, 7)).to(device)
    net = MyNet().to(device)
    out = net(x)
    
    summary(net, (3, 7, 7))
    


结语

从扩展性、可读性的角度来说,更偏向于方法2的设计。

如何用pytorch实现一个分类器?(代码片段)

学习目标了解分类器的任务和数据样式掌握如何用Pytorch实现一个分类器分类器任务和数据介绍构造一个将不同图像进行分类的神经网络分类器,对输入的图片进行判别并完成分类.本案例采用CIFAR10数据集作为原始图片数据.CIFAR10数... 查看详情

小白学习pytorch教程十六在多标签分类任务上微调bert模型(代码片段)

...类,并给出了各自的步骤。参考官方教程:https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html复旦大学邱锡鹏老师课题组的研究论文《HowtoFine-TuneBERTforTextClassification?》。论文:https://arxiv.org/pdf/1905.05583.pdf这篇... 查看详情

用一个图像分类实例拿捏pytorch使用方法(代码片段)

写在最前边这篇文章要写的内容看封面,就是要用一篇文章讲解一下,怎么用Fashion-MNIST数据集,我们自己建一个神经网络,训练好之后用它做图片分类。importtorchfromtorchimportnnfromtorch.utils.dataimportDataLoaderfromtorchvisionimportdatasetsfromt... 查看详情

pytorch开源图像分类算法框架(代码片段)

PyTorch的开源图像分类算法框架pytorch-image-models,功能完善,集成大量数据增强方法和主流的网络框架,同时易用。pytorch-image-models:https://github.com/rwightman/pytorch-image-models训练使用方法:训练脚本:train.py 查看详情

基于torchtext的pytorch文本分类(代码片段)

...量的预处理和大量的计算资源。在这篇文章中,我们使用PyTorch来进行多类文本分类,因为它有如下优点:PyTorch提供了一种强大的方法来实现复杂的模型体系结构和算法,其预处理量相对较少,计算资源(包括执行时 查看详情

pytorch快速搭建多个具有相同卷积核尺寸不同卷积核数量的卷积层(代码片段)

问题本文主要介绍啊快速搭建多个具有相同卷积核尺寸不同卷积核数量的卷积层。方法快速搭建三个卷积层结构,每个卷积层的参数如下所示:3@3x316@3x332@3x364@3x3x@yxy其中x为卷积核数量,y为卷积核大小简单方法importtorchfromtorchimport... 查看详情

pytorch快速搭建多个具有相同卷积核尺寸不同卷积核数量的卷积层(代码片段)

问题本文主要介绍啊快速搭建多个具有相同卷积核尺寸不同卷积核数量的卷积层。方法快速搭建三个卷积层结构,每个卷积层的参数如下所示:3@3x316@3x332@3x364@3x3x@yxy其中x为卷积核数量,y为卷积核大小简单方法importtorchfromtorchimport... 查看详情

pytorch1.0中文官方教程:对抗性示例生成(代码片段)

译者:cangyunye作者:NathanInkawhich如果你正在阅读这篇文章,希望你能理解一些机器学习模型是多么有效。现在的研究正在不断推动ML模型变得更快、更准确和更高效。然而,在设计和训练模型中经常会忽视的是安全性和健壮性方面... 查看详情

图像分类convit从入门到实战——使用convit实现植物幼苗的分类(pytorch)(代码片段)

摘要来自Facebook的研究者提出了一种名为ConViT的新计算机视觉模型,它结合了两种广泛使用的AI架构——卷积神经网络(CNN)和Transformer,该模型取长补短,克服了CNN和Transformer本身的一些局限性。同时,借助这两种架... 查看详情

基于pytorch实现的声音分类(代码片段)

前言本章我们来介绍如何使用Pytorch训练一个区分不同音频的分类模型,例如你有这样一个需求,需要根据不同的鸟叫声识别是什么种类的鸟,这时你就可以使用这个方法来实现你的需求了。源码地址:https://github.... 查看详情

pytorch实现文本情感分类流程(代码片段)

文章目录基本概念介绍文本情感分类准备数据集文本的序列化构建模型模型的训练与评估完整代码基本概念介绍tokenization:分词,每个词语就是一个token分词方法:转化为单个字(常见)切分词语N-gram:准... 查看详情

pytorch两种构建网络的方法(代码片段)

版本:Pytorch1.0 代码是在jupter中执行的。导包:importtorchimporttorch.nnasnnimporttorch.nn.functionalasFimporttorch.optimasoptimfromtorchvisionimportdatasets,transforms设置超参:BATCH_SIZE=512#大概需要2G的显存EPOCHS=20#总共训练批次DEVICE=torch.device("cuda"iftor... 查看详情

pytorch不同层设置不同学习率(代码片段)

1主要目标不同的参数可能需要不同的学习率,本文主要实现的是不同层中参数的不同学习率设置。尤其是当我们在使用预训练的模型时,需要对一些除了主干网络以外的分支进行单独修改并进行初始化,其他主干网... 查看详情

php两种获取分类树的方法(代码片段)

php两种获取分类树的方法 1./***获取分类树*@paramarray$array数据源*@paramint$pid父级ID*@paramint$level分类级别*@returnstring*/functiongetCategory($array,$pid=0,$level=0)//声明静态数组,避免递归调用时,多次声明导致数组覆盖static$list=[];foreach($ 查看详情

机器学习实战第7章——利用adaboost元算法提高分类性能(代码片段)

...集成,等等接下来介绍基于同一种分类器多个不同实例的两种不同计算方法bagging和boosting1.bagging  原理:从原始数据集选择S次后得到S个新数据集的一种技 查看详情

6.2pytorch图像的多分类--手写字体识别(代码片段)

欢迎订阅本专栏:《PyTorch深度学习实践》订阅地址:https://blog.csdn.net/sinat_33761963/category_9720080.html第二章:认识Tensor的类型、创建、存储、api等,打好Tensor的基础,是进行PyTorch深度学习实践的重中之重的基础... 查看详情

动手学习pytorch——多层感知机(代码片段)

...哪个容易造成梯度消失,使用较多。   多层感知机pytorch实现如下:importtorchfromtorchimportnnfromtorch.nnimportinitimportnumpyasnpimportsyssys.path.append("/home/kesci/input")importd2lzh1981asd2lnum_inputs,num_outputs,num_hiddens=784,10,256net=nn.Sequential(d2l.Flatte... 查看详情

鞋子,靴子,拖鞋傻傻分不清楚pytorch实现分类入门小案例(代码片段)

鞋子,靴子,拖鞋傻傻分不清楚pytorch入门前言方法网络优化器损失函数总体方法代码实现图片加字神经网络总结前言从入学到现在已经两个多月了,看了一个多月的论文不知道学到了啥正好最近看了看pytorch的入门... 查看详情