vit-pytorch实现mobilevit注意力可视化(代码片段)

白十月 白十月     2023-03-02     733

关键词:

项目链接 https://github.com/lucidrains/vit-pytorch

注意一下参数设置:

Parameters

  1. image_size: int.
    Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
  2. patch_size: int.
    Number of patches. image_size must be divisible by patch_size.
    The number of patches is: n = (image_size // patch_size) ** 2 and n must be greater than 16.
  3. num_classes: int.
    Number of classes to classify.
  4. dim: int.
    Last dimension of output tensor after linear transformation nn.Linear(…, dim).
  5. depth: int.
    Number of Transformer blocks.
  6. heads: int.
    Number of heads in Multi-head Attention layer.
  7. mlp_dim: int.
    Dimension of the MLP (FeedForward) layer.
  8. channels: int, default 3.
    Number of image’s channels.
  9. dropout: float between [0, 1], default 0…
    Dropout rate.
  10. emb_dropout: float between [0, 1], default 0.
    Embedding dropout rate.
  11. pool: string, either cls token pooling or mean pooling

image_size:表示图像大小的整数。图片应该是正方形的,并且image_size必须是宽度和高度中的最大值。
patch_size:表示补丁大小的整数。image_size必须能被 整除patch_size。补丁的数量计算为n =
(image_size // patch_size) ** 2并且n必须大于 16。 num_classes:一个整数,表示要分类的类数。
dim:一个整数,表示线性变换后输出张量的最后一维nn.Linear(…, dim)。 depth:一个整数,表示
Transformer 块的数量。 heads:一个整数,表示多头注意力层中的头数。 mlp_dim:一个整数,表示
MLP(前馈)层的维度。 channels:一个整数,表示图像中的通道数,默认值为3。 dropout:一个介于 0 和 1
之间的浮点数,代表辍学率。 emb_dropout:一个介于 0 和 1 之间的浮点数,表示嵌入丢失率。
pool:表示池化方法的字符串,可以是“cls token pooling”或“mean pooling”。

快速使用实例

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

SimpleViT

来自原论文的一些作者的更新建议对ViT进行简化,使其能够更快更好地训练。

这些简化包括2d正弦波位置嵌入、全局平均池(无CLS标记)、无辍学、批次大小为1024而不是4096,以及使用RandAugment和MixUp增强。他们还表明,最后的简单线性并不明显比原始MLP头差。

你可以通过导入SimpleViT来使用它,如下图所示

import torch
from vit_pytorch import SimpleViT

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

可视化
Accessing Attention
If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below

import torch
from vit_pytorch.vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# import Recorder and wrap the ViT

from vit_pytorch.recorder import Recorder
v = Recorder(v)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)

# there is one extra patch due to the CLS token

attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)


本文介绍了 MobileViT,一种用于移动设备的轻量级通用视觉转换器。MobileViT 为全球信息处理与转换器提供了不同的视角。

您可以将其与以下代码一起使用(例如 mobilevit_xs)

import torch
from vit_pytorch.mobile_vit import MobileViT

mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)

pred = mbvit_xs(img) # (1, 1000)

mobilevit:挑战mobilenet端侧霸主

...任务构建一个轻量级、低延迟的网络?为此,作者提出了MobileViT,一种用于移动设备的轻量级通用视觉Transformer。那现在!是否有可能结合CNN和ViT的优势,为移动视觉任务构建一个轻量级、低延迟的网络?为此,作者提出了MobileV... 查看详情

fasternet:cvpr2023年最新的网络,基于部分卷积pconv,性能远超mobilenet,mobilevit

文章目录摘要1、简介2、相关工作4、实验结果4.1.PConv是快速的,但具有很高的FLOPS4.2.PConv与PWConv同时有效4.3.fastnet对ImageNet-1k分类4.4.下游任务的fastnet4.5、消融实现5、结论附录A.ImageNet-1k实验设置B.下游任务实验设置C.ImageNet-1k上... 查看详情

芒果改进yolov7系列:全网首发最新iclr2022顶会|轻量通用的mobilevit结构transformer,轻量级通用且移动友好的视觉转换器,高效涨点

查看详情

注意力模型实现 Keras/Theano

】注意力模型实现Keras/Theano【英文标题】:AttentionModelimplementationKeras/Theano【发布时间】:2016-09-1116:02:21【问题描述】:我正在寻找使用Keras/Theano的机器翻译注意力模型的实现。我遇到过像土拨鼠这样的库,但我正在寻找一些基... 查看详情

tensorflow实现自注意力机制(self-attention)(代码片段)

TensorFlow实现自注意力机制(Self-attention)自注意力机制(Self-attention)计算机视觉中的自注意力Tensorflow实现自注意力模块自注意力机制(Self-attention)自注意力机制(Self-attention)随着自然语言处理(NaturalLanguageProcessing,NLP)模型࿰... 查看详情

深度学习中一些注意力机制的介绍以及pytorch代码实现(代码片段)

文章目录前言注意力机制软注意力机制代码实现硬注意力机制多头注意力机制代码实现参考前言因为最近看论文发现同一个模型用了不同的注意力机制计算方法,因此懵了好久,原来注意力机制也是多种多样的,为了... 查看详情

机器翻译注意力机制及其pytorch实现(代码片段)

前面阐述注意力理论知识,后面简单描述PyTorch利用注意力实现机器翻译EffectiveApproachestoAttention-basedNeuralMachineTranslation简介Attention介绍在翻译的时候,选择性的选择一些重要信息。详情看这篇文章 。本着简单和有效的原则,... 查看详情

注意力实现的预测对齐中的差异化问题

】注意力实现的预测对齐中的差异化问题【英文标题】:DifferentiationIssueinPredictiveAlignmentforAttentionImplementation【发布时间】:2019-11-0207:44:54【问题描述】:我正在尝试基于这篇论文实现local-pattention:https://arxiv.org/pdf/1508.04025.pdf具... 查看详情

注意力经济当前,如何实现高效阅读?

...#xff01;接下来将是有史以来你看过的最短的一篇文章!注意力经济当前,什么是阻碍?答:广告频现!碎片化阅读时代,什么最难得?答:沉浸式阅读!CSDN浏览器助手推出最新版本,新增“... 查看详情

se注意力模块原理分析与代码实现(代码片段)

前言本文介绍SE注意力模块,它是在SENet中提出的,SENet是ImageNet2017的冠军模型;SE模块常常被用于CV模型中,能较有效提取模型精度,所以给大家介绍一下它的原理,设计思路,代码实现,如何应用... 查看详情

se注意力模块原理分析与代码实现(代码片段)

前言本文介绍SE注意力模块,它是在SENet中提出的,SENet是ImageNet2017的冠军模型;SE模块常常被用于CV模型中,能较有效提取模型精度,所以给大家介绍一下它的原理,设计思路,代码实现,如何应用... 查看详情

se注意力模块原理分析与代码实现(代码片段)

前言本文介绍SE注意力模块,它是在SENet中提出的,SENet是ImageNet2017的冠军模型;SE模块常常被用于CV模型中,能较有效提取模型精度,所以给大家介绍一下它的原理,设计思路,代码实现,如何应用... 查看详情

eca注意力模块原理分析与代码实现(代码片段)

前言本文介绍ECA注意力模块,它是在ECA-Net中提出的,ECA-Net是2020CVPR中的论文;ECA模块可以被用于CV模型中,能提取模型精度,所以给大家介绍一下它的原理,设计思路,代码实现,如何应用在模型... 查看详情

eca注意力模块原理分析与代码实现(代码片段)

前言本文介绍ECA注意力模块,它是在ECA-Net中提出的,ECA-Net是2020CVPR中的论文;ECA模块可以被用于CV模型中,能提取模型精度,所以给大家介绍一下它的原理,设计思路,代码实现,如何应用在模型... 查看详情

python实现注意力机制

...理能力。在处理视觉数据的初期,人类视觉系统会迅速将注意力集中在场景中的重要区域上,这一选择性感知机制极大地减少了人类视觉系统 查看详情

sk注意力模块原理分析与代码实现(代码片段)

前言本文介绍SK模块,一种通道注意力模块,它是在SK-Nets中提出的,SK-Nets是2019 CVPR中的论文;SK模块可以被用于CV模型中,能提取模型精度,所以给大家介绍一下它的原理,设计思路,代码实现ÿ... 查看详情

sk注意力模块原理分析与代码实现(代码片段)

前言本文介绍SK模块,一种通道注意力模块,它是在SK-Nets中提出的,SK-Nets是2019 CVPR中的论文;SK模块可以被用于CV模型中,能提取模型精度,所以给大家介绍一下它的原理,设计思路,代码实现ÿ... 查看详情

前沿系列--transform架构[架构分析+代码实现](代码片段)

...体任务使用输入部分EmbeddingPositionEncodingwhy实现注意部分注意力机制/自注意力掩码作用如何工作形状解释完整实现多头注意力实现Norm处理FeedForward以及连接编码器解码器中间层组装输出层模型组装总结前言Transform这玩意的大名我... 查看详情