聊一聊计算机视觉中常用的注意力机制附pytorch代码实现(代码片段)

肆十二 肆十二     2023-03-01     796

关键词:

聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现

注意力机制(Attention)是深度学习中常用的tricks,可以在模型原有的基础上直接插入,进一步增强你模型的性能。注意力机制起初是作为自然语言处理中的工作Attention Is All You Need被大家所熟知,从而也引发了一系列的XX is All You Need的论文命题,SENET-Squeeze-and-Excitation Networks是注意力机制在计算机视觉中应用的早期工作之一,并获得了2017年imagenet, 同时也是最后一届Imagenet比赛的冠军,后面就又出现了各种各样的注意力机制,应用在计算机视觉的任务中,今天我们就来一起聊一聊计算机视觉中常用的注意力机制以及他们对应的Pytorch代码实现,另外我还使用这些注意力机制做了一些目标检测的实验,实验效果我也一并放在博客中,大家可以一起对自己感兴趣的部分讨论讨论。

新出的手把手教程,感兴趣的兄弟们快去自己动手试试看!

手把手教你使用YOLOV5训练自己的目标检测模型-口罩检测-视频教程_dejahu的博客-CSDN博客

这里是我数据集的基本情况,这里我使用的是交通标志检测的数据集

CocoDataset Train dataset with number of images 2226, and instance counts: 
+------------+-------+-----------+-------+-----------+-------+-----------------------------+-------+---------------------+-------+
| category   | count | category  | count | category  | count | category                    | count | category            | count |
+------------+-------+-----------+-------+-----------+-------+-----------------------------+-------+---------------------+-------+
| 0 [red_tl] | 1465  | 1 [arr_s] | 1133  | 2 [arr_l] | 638   | 3 [no_driving_mark_allsort] | 622   | 4 [no_parking_mark] | 1142  |
+------------+-------+-----------+-------+-----------+-------+-----------------------------+-------+---------------------+-------+

baseline选择的是fasterrcnn,实验的结果如下:

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.341
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=1000 ] = 0.502
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=1000 ] = 0.400
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.115
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.473
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.655
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.417
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 ] = 0.417
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=1000 ] = 0.417
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.156
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.570
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.726

如果大家遇到论文下载比较慢

推荐使用中科院的 arxiv 镜像: http://xxx.itp.ac.cn, 国内网络能流畅访问
简单直接的方法是, 把要访问 arxiv 链接中的域名从 https://arxiv.org 换成 http://xxx.itp.ac.cn , 比如:

https://arxiv.org/abs/1901.07249 改为 http://xxx.itp.ac.cn/abs/1901.07249

1. SeNet: Squeeze-and-Excitation Attention

论文地址:https://arxiv.org/abs/1709.01507

  • 网络结构

    对通道做注意力机制,通过全连接层对每个通道进行加权。

  • Pytorch代码

    import numpy as np
    import torch
    from torch import nn
    from torch.nn import init
    
    
    class SEAttention(nn.Module):
    
        def __init__(self, channel=512, reduction=16):
            super().__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Sequential(
                nn.Linear(channel, channel // reduction, bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(channel // reduction, channel, bias=False),
                nn.Sigmoid()
            )
    
        def init_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    init.kaiming_normal_(m.weight, mode='fan_out')
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    init.normal_(m.weight, std=0.001)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
    
        def forward(self, x):
            b, c, _, _ = x.size()
            y = self.avg_pool(x).view(b, c)
            y = self.fc(y).view(b, c, 1, 1)
            return x * y.expand_as(x)
    
    
    if __name__ == '__main__':
        input = torch.randn(50, 512, 7, 7)
        se = SEAttention(channel=512, reduction=8)
        output = se(input)
        print(output.shape)
    
    
  • 实验结果

     Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.338
     Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=1000 ] = 0.511
     Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=1000 ] = 0.375
     Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.126
     Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.458
     Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.696
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.411
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 ] = 0.411
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=1000 ] = 0.411
     Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.163
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.551
     Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.758
    

2. (有用)CBAM: Convolutional Block Attention Module

论文地址:CBAM: Convolutional Block Attention Module

  • 网络结构

    对通道方向上做注意力机制之后再对空间方向上做注意力机制

  • Pytorch代码

    import numpy as np
    import torch
    from torch import nn
    from torch.nn import init
    
    
    class ChannelAttention(nn.Module):
        def __init__(self, channel, reduction=16):
            super().__init__()
            self.maxpool = nn.AdaptiveMaxPool2d(1)
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            self.se = nn.Sequential(
                nn.Conv2d(channel, channel // reduction, 1, bias=False),
                nn.ReLU(),
                nn.Conv2d(channel // reduction, channel, 1, bias=False)
            )
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            max_result = self.maxpool(x)
            avg_result = self.avgpool(x)
            max_out = self.se(max_result)
            avg_out = self.se(avg_result)
            output = self.sigmoid(max_out + avg_out)
            return output
    
    
    class SpatialAttention(nn.Module):
        def __init__(self, kernel_size=7):
            super().__init__()
            self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            max_result, _ = torch.max(x, dim=1, keepdim=True)
            avg_result = torch.mean(x, dim=1, keepdim=True)
            result = torch.cat([max_result, avg_result], 1)
            output = self.conv(result)
            output = self.sigmoid(output)
            return output
    
    
    class CBAMBlock(nn.Module):
    
        def __init__(self, channel=512, reduction=16, kernel_size=49):
            super().__init__()
            self.ca = ChannelAttention(channel=channel, reduction=reduction)
            self.sa = SpatialAttention(kernel_size=kernel_size)
    
        def init_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    init.kaiming_normal_(m.weight, mode='fan_out')
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    init.normal_(m.weight, std=0.001)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
    
        def forward(self, x):
            b, c, _, _ = x.size()
            residual = x
            out = x * self.ca(x)
            out = out * self.sa(out)
            return out + residual
    
    
    if __name__ == '__main__':
        input = torch.randn(50, 512, 7, 7)
        kernel_size = input.shape[2]
        cbam = CBAMBlock(channel=512, reduction=16, kernel_size=kernel_size)
        output = cbam(input)
        print(output.shape)
    
    
  • 实验结果

     Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.364
     Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=1000 ] = 0.544
     Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=1000 ] = 0.425
     Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.137
     Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.499
     Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.674
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.439
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 ] = 0.439
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=1000 ] = 0.439
     Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.185
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.590
     Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.755
    

3. BAM: Bottleneck Attention Module

论文地址:https://arxiv.org/pdf/1807.06514.pdf

  • 网络结构

  • Pytorch代码

    import numpy as np
    import torch
    from torch import nn
    from torch.nn import init
    
    
    class Flatten(nn.Module):
        def forward(self, x):
            return x.view(x.shape[0], -1)
    
    
    class ChannelAttention(nn.Module):
        def __init__(self, channel, reduction=16, num_layers=3):
            super().__init__()
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            gate_channels = [channel]
            gate_channels += [channel // reduction] * num_layers
            gate_channels += [channel]
    
            self.ca = nn.Sequential()
            self.ca.add_module('flatten', Flatten())
            for i in range(len(gate_channels) - 2):
                self.ca.add_module('fc%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))
                self.ca.add_module('bn%d' % i, nn.BatchNorm1d(gate_channels[i + 1]))
                self.ca.add_module('relu%d' % i, nn.ReLU())
            self.ca.add_module(综述|计算机视觉中的注意力机制
    

    ↑点击蓝字 关注AI派作者丨HUST小菜鸡@知乎来源丨https://zhuanlan.zhihu.com/p/146130215编辑丨极市平台导读 Attention是一种资源分配的机制,在视觉方面,其核心思想是突出对象的某些重要特征。本文主要介绍了多种注意力机... 查看详情

    计算机视觉中的注意力机制研究

    ...交上去就完了,所以发上来留个纪念。将注意力机制用在计算机视觉任务上,可以有效捕捉图片中有用的区域,从而提升整体网络性能。计算机视觉领域的注意力机制主要分为两类:(1)self-attention;(2)scaleattention。这两类注意力... 查看详情

    聊一聊android的消息机制(代码片段)

    聊一聊Android的消息机制侯亮1概述在Android平台上,主要用到两种通信机制,即Binder机制和消息机制,前者用于跨进程通信,后者用于进程内部通信。从技术实现上来说,消息机制还是比较简单的。从大的方面... 查看详情

    聊一聊vue实例与生命周期运行机制

    Vue的实例是Vue框架的入口,担任MVVM中的ViewModel角色,所有功能的实现都是围绕其生命周期进行的,在生命周期的不同阶段调用对应的钩子函数可以实现组件数据管理和DOM渲染两大重要功能。例如,实例需要配置数据观测(dataobserv... 查看详情

    透过现象看本质——聊一聊docker的硬件资源控制与验证(代码片段)

    透过现象看本质——聊一聊docker的硬件资源控制与验证前言?前面两篇文章主要介绍了有关docker的基础概念、安装、以及对镜像容器的相关操作。重点在于命令的含义以及常用的一些命令的可选项的含义的理解,本文在此基础上... 查看详情

    深度学习中一些注意力机制的介绍以及pytorch代码实现(代码片段)

    文章目录前言注意力机制软注意力机制代码实现硬注意力机制多头注意力机制代码实现参考前言因为最近看论文发现同一个模型用了不同的注意力机制计算方法,因此懵了好久,原来注意力机制也是多种多样的,为了... 查看详情

    清华&南开出品最新视觉注意力机制attention综述

    ...、卡迪夫大学RalphR.Martin教授合作,在ArXiv上发布关于计算机视觉中的注意力机制的综述文章。该综述系统地介绍了注意力机制在计算机视觉领域中相关工作清华计图胡事民团队的这篇注意力机制的综述火了!在上周的arXiv... 查看详情

    聊一聊顺序消息(rocketmq顺序消息的实现机制)(代码片段)

    当我们说顺序时,我们在说什么?日常思维中,顺序大部分情况会和时间关联起来,即时间的先后表示事件的顺序关系。比如事件A发生在下午3点一刻,而事件B发生在下午4点,那么我们认为事件A发生在事件B之前,他们的顺序关... 查看详情

    聊一聊编程字体--clwu

     正确状态(220-221行,VIM): 不正确状态(220-221行,PhpStorm):  在VIM中的用的是等宽字体(console),注意不是GVIM,在terminal终端中使用VIM,VIM中的字体属性使用的就是terminal终端中的属性,   在 Ph... 查看详情

    聊一聊spark写文件的机制——如何保证数据一致性(代码片段)

      聊这个问题的原因是,本周在测试环境遇到了一例从Spark往S3写数据失败的情况,花了些时间来搞清楚个中缘由,这里整理出来与大家分享,期望能对同道中人有所帮助。背景  在笔者的数据系统中,每... 查看详情

    聊一聊spark写文件的机制——如何保证数据一致性(代码片段)

      聊这个问题的原因是,本周在测试环境遇到了一例从Spark往S3写数据失败的情况,花了些时间来搞清楚个中缘由,这里整理出来与大家分享,期望能对同道中人有所帮助。背景  在笔者的数据系统中,每... 查看详情

    游戏货不对版怎么办?聊一聊游戏的退款机制

    650)this.width=650;"src="https://img.mp.itc.cn/upload/20170424/73491a7f51614507a4ebb5f6ccb45b75_th.png"alt="73491a7f51614507a4ebb5f6ccb45b75_th.png"/>   PC、智能手机、次时代主机等终端设备的全面普及,让游戏成为当下普 查看详情

    动手学pytorch-注意力机制和seq2seq模型(代码片段)

    注意力机制和Seq2Seq模型1.基本概念2.两种常用的attention层3.带注意力机制的Seq2Seq模型4.实验1.基本概念Attention是一种通用的带权池化方法,输入由两部分构成:询问(query)和键值对(key-valuepairs)。(??_??∈?^??_??,??_??∈?^??_??).Query(?... 查看详情

    杂谈:聊一聊nio(代码片段)

    ...,也可以说是Java中的一个包(NIO包是在java1.4中引入的).IO是计算机与其他部分之间的接口,IO可以分为多种:BIO,NIO,AIONIO模型2.2NIO,AIO,BIO的区别BIO:基于流(Stre 查看详情

    视觉注意力机制——通道注意力空间注意力自注意力

    前言本文介绍注意力机制的概念和基本原理,并站在计算机视觉CV角度,进一步介绍通道注意力、空间注意力、混合注意力、自注意力等。目录前言一、注意力机制二、通道注意力机制三、空间注意力机制四、混合注意力... 查看详情

    视觉注意力机制——通道注意力空间注意力自注意力

    前言本文介绍注意力机制的概念和基本原理,并站在计算机视觉CV角度,进一步介绍通道注意力、空间注意力、混合注意力、自注意力等。目录前言一、注意力机制二、通道注意力机制三、空间注意力机制四、混合注意力... 查看详情

    计算机视觉pytorch实现(代码片段)

    计算机视觉PyTorch实现(一)PyTorch基础模块计算机视觉可以被广泛应用于多个现实领域中。如做图像基本处理、图像识别、图像分割、目标跟踪、图像分类、姿态估计等。在深度学习中人们开发了很多的学习框架,如C... 查看详情

    你期待已久的《动手学深度学习》(pytorch版)来啦!

    ...度学习计算性能的重要因素,并分别列举深度学习在计算机视觉 查看详情