图像分类convit从入门到实战——使用convit实现植物幼苗的分类(pytorch)(代码片段)

AI浩 AI浩     2022-12-13     277

关键词:

摘要

来自 Facebook 的研究者提出了一种名为 ConViT 的新计算机视觉模型,它结合了两种广泛使用的 AI 架构——卷积神经网络 (CNN) 和 Transformer,该模型取长补短,克服了 CNN 和 Transformer 本身的一些局限性。同时,借助这两种架构的优势,这种基于视觉 Transformer 的模型可以胜过现有架构,尤其是在小数据的情况下,同时在大数据的情况下也能实现类似的优秀性能。

• 论文地址:https://arxiv.org/pdf/2103.10697.pdf
• 论文翻译:ConViT:使用软卷积归纳偏置改进视觉变换器
• GitHub 地址:https://github.com/facebookresearch/convit
其工作原理
ConViT 在 vision Transformer 的基础上进行了调整,以利用 soft 卷积归纳偏置,从而激励网络进行卷积操作。同时最重要的是,ConViT 允许模型自行决定是否要保持卷积。为了利用这种 soft 归纳偏置,研究者引入了一种称为「门控位置自注意力(gated positional self-attention,GPSA)」的位置自注意力形式,其模型学习门控参数 lambda,该参数用于平衡基于内容的自注意力和卷积初始化位置自注意力。

如上图所示,ConViT(左)在 ViT 的基础上,将一些自注意力(SA)层用门控位置自注意力层(GPSA,右)替代。因为 GPSA 层涉及位置信息,因此在最后一个 GPSA层之后,类 token 会与隐藏表征联系到一起。
有了 GPSA 层加持,ConViT 的性能优于 Facebook 去年提出的 DeiT 模型。例如,ConViT-S+ 性能略优于 DeiT-B(对比结果为 82.2% vs. 81.8%),而 ConViT-S + 使用的参数量只有 DeiT-B 的一半左右 (48M vs 86M)。而 ConViT 最大的改进是在有限的数据范围内,soft 卷积归纳偏置发挥了重要作用。例如,仅使用 5% 的训练数据时,ConViT 的性能明显优于 DeiT(对比结果为 47.8% vs. 34.8%)。

此外,ConViT 在样本效率和参数效率方面也都优于 DeiT。如上图所示,左图为 ConViT-S 与 DeiT-S 的样本效率对比结果,这两个模型是在相同的超参数,且都是在 ImageNet-1k 的子集上训练完成的。图中绿色折线是 ConViT 相对于 DeiT 的提升。研究者还在 ImageNet-1k 上比较了 ConViT 模型与其他 ViT 以及 CNN 的 top-1 准确率,如上右图所示。
除了 ConViT 的性能优势外,门控参数提供了一种简单的方法来理解模型训练后每一层的卷积程度。查看所有层,研究者发现 ConViT 在训练过程中对卷积位置注意力的关注逐渐减少。对于靠后的层,门控参数最终会收敛到接近 0,这表明卷积归纳偏置实际上被忽略了。然而,对于起始层来说,许多注意力头保持较高的门控值,这表明该网络利用早期层的卷积归纳偏置来辅助训练。


上图展示了 DeiT (b) 及 ConViT © 注意力图的几个例子。σ(λ) 表示可学习的门控参数。接近 1 的值表示使用了卷积初始化,而接近 0 的值表示只使用了基于内容的注意力。注意,早期的 ConViT 层部分地维护了卷积初始化,而后面的层则完全基于内容。

导入项目使用的库

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from models import convit_tiny
from torch.autograd import Variable
from mydatasets import SeedlingData
import os

设置全局参数

设置使用哪块GPU,0指的是第一块GPU,设置学习率、BatchSize、epoch等参数

os.environ['CUDA_VISIBLE_DEVICE'] = '0'

torch.cuda.set_device(0)
# 设置全局参数
modellr = 1e-4
BATCH_SIZE = 16
EPOCHS = 100
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

图像预处理

数据处理比较简单,没有做复杂的尝试,有兴趣的可以加入一些处理。

# 数据预处理

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

读取数据

将数据集解压后放到data文件夹下面,如图:

然后我们在dataset文件夹下面新建 init.py和dataset.py,在mydatasets.py文件夹写入下面的代码:

说一下代码的核心逻辑。

第一步 建立字典,定义类别对应的ID,用数字代替类别。

第二步 在__init__里面编写获取图片路径的方法。测试集只有一层路径直接读取,训练集在train文件夹下面是类别文件夹,先获取到类别,再获取到具体的图片路径。然后使用sklearn中切分数据集的方法,按照7:3的比例切分训练集和验证集。

第三步 在__getitem__方法中定义读取单个图片和类别的方法,由于图像中有位深度32位的,所以我在读取图像的时候做了转换。

# coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
from sklearn.model_selection import train_test_split
 
Labels = 'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3,
          'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8,
          'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11
 
 
class SeedlingData (data.Dataset):
 
    def __init__(self, root, transforms=None, train=True, test=False):
        """
        主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
        """
        self.test = test
        self.transforms = transforms
 
        if self.test:
            imgs = [os.path.join(root, img) for img in os.listdir(root)]
            self.imgs = imgs
        else:
            imgs_labels = [os.path.join(root, img) for img in os.listdir(root)]
            imgs = []
            for imglable in imgs_labels:
                for imgname in os.listdir(imglable):
                    imgpath = os.path.join(imglable, imgname)
                    imgs.append(imgpath)
            trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)
            if train:
                self.imgs = trainval_files
            else:
                self.imgs = val_files
 
    def __getitem__(self, index):
        """
        一次返回一张图片的数据
        """
        img_path = self.imgs[index]
        img_path=img_path.replace("\\\\",'/')
        if self.test:
            label = -1
        else:
            labelname = img_path.split('/')[-2]
            label = Labels[labelname]
        data = Image.open(img_path).convert('RGB')
        data = self.transforms(data)
        return data, label
 
    def __len__(self):
        return len(self.imgs)

然后我们在train.py调用SeedlingData读取数据 ,记着导入刚才写的dataset.py(from mydatasets import SeedlingData)

dataset_train = SeedlingData('data/train', transforms=transform, train=True)
dataset_test = SeedlingData("data/train", transforms=transform_test, train=False)
# 读取数据
print(dataset_train.imgs)

# 导入数据
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型

  • 设置loss函数为nn.CrossEntropyLoss()。
  • 设置模型为convit_tiny,num_classes设置为12,embed_dim是编码的个数,设置为48,dropout设置为0.5。
  • 优化器设置为adam。
# 实例化模型并且移动到GPU
criterion = nn.CrossEntropyLoss()
model_ft = convit_tiny(pretrained=False, num_classes=12, embed_dim=48, drop_rate=0.5)

model_ft.to(DEVICE)
# 选择简单暴力的Adam优化器,学习率调低
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    modellrnew = modellr * (0.1 ** (epoch // 50))
    print("lr:", modellrnew)
    for param_group in optimizer.param_groups:
        param_group['lr'] = modellrnew

定义训练和验证函数

# 定义训练过程

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch:  [/ (:.0f%)]\\tLoss: :.6f'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:,loss:'.format(epoch, ave_loss))


# 验证过程
def val(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        print('\\nVal set: Average loss: :.4f, Accuracy: / (:.0f%)\\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))


# 训练

for epoch in range(1, EPOCHS + 1):
    adjust_learning_rate(optimizer, epoch)
    train(model_ft, DEVICE, train_loader, optimizer, epoch)
    val(model_ft, DEVICE, test_loader)
torch.save(model_ft, 'model.pth')

测试

我介绍两种常用的测试方式,第一种是通用的,通过自己手动加载数据集然后做预测,具体操作如下:

测试集存放的目录如下图:

第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!

第二步 定义transforms,transforms和验证集的transforms一样即可,别做数据增强。

第三步 加载model,并将模型放在DEVICE里,

第四步 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
           'Common wheat','Fat Hen', 'Loose Silky-bent',
           'Maize','Scentless Mayweed','Shepherds Purse','Small-flowered Cranesbill','Sugar beet')
transform_test = transforms.Compose([
         transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
 
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)
 
path='data/test/'
testList=os.listdir(path)
for file in testList:
        img=Image.open(path+file)
        img=transform_test(img)
        img.unsqueeze_(0)
        img = Variable(img).to(DEVICE)
        out=model(img)
        # Predict
        _, pred = torch.max(out.data, 1)
        print('Image Name:,predict:'.format(file,classes[pred.data.item()]))

第二种 使用自定义的Dataset读取图片

import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variable
 
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
           'Common wheat','Fat Hen', 'Loose Silky-bent',
           'Maize','Scentless Mayweed','Shepherds Purse','Small-flowered Cranesbill','Sugar beet')
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
 
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)
 
dataset_test =SeedlingData('data/test/', transform_test,test=True)
print(len(dataset_test))
# 对应文件夹的label
 
for index in range(len(dataset_test)):
    item = dataset_test[index]
    img, label = item
    img.unsqueeze_(0)
    data = Variable(img).to(DEVICE)
    output = model(data)
    _, pred = torch.max(output.data, 1)
    print('Image Name:,predict:'.format(dataset_test.imgs[index], classes[pred.data.item()]))
    index += 1
 

完整的训练代码

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variable
from torchvision.models import vgg16
 
# 设置全局参数
modellr = 1e-4
BATCH_SIZE = 32
EPOCHS = 10
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
# 数据预处理
 
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
 
])
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
dataset_train = SeedlingData('data/train', transforms=transform, train=True)
dataset_test = SeedlingData("data/train", transforms=transform_test, train=False)
# 读取数据
print(dataset_train.imgs)
 
# 导入数据
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
 
# 实例化模型并且移动到GPU
criterion = nn.CrossEntropyLoss()
model_ft = vgg16(pretrained=True)
model_ft.classifier = classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7,

图像分类实战:mobilenetv2从训练到tensorrt部署(pytorch)(代码片段)

...orRT项目结构训练数据增强Cutout和Mixup导入包设置全局参数图像预处理与增强读取数据设置模型定义训练和验证函数测试模型转化及推理转onnxonnx推理转TensorRTTensorRT推理摘要本例提取了植物幼苗数据集中的部分数据做数据集,... 查看详情

opencv|opencv实战从入门到精通系列二--opencv图像腐蚀

...第2篇OpenCV|OpenCV实战从入门到精通系列一--OpenCV宏的讲解图像腐蚀形态学转换主要针对的是二值图像(0或1)。图像腐蚀类似于“领域被蚕食”,将图像中的高亮区域或白色部分进行缩减细化,其运行结果图比原图... 查看详情

入门实战《深度学习技术图像处理入门》+《视觉slam十四讲从理论到实践》

学习图像识别处理,想在数据分析竞赛中取得较高的排名,看了《深度学习技术图像处理入门》电子书,一边看电子书一边做标记,对配套的代码也做了测试,收获颇多。 从机器学习、图像处理的基本概念入手,逐步阐述深... 查看详情

redis从入门到实战(一redis简介)

...​1.1、什么是NoSQL?​​​​1.2、NoSQL特点​​​​1.3、分类​​在学习redis之前,我们先学习一下NoSQL。1、NoSQL1.1、什么是NoSQL?NoSQL概念在2009 查看详情

opencv|opencv实战从入门到精通系列四--常用函数讲解

...--OpenCV宏的讲解OpenCV|OpenCV实战从入门到精通系列二--OpenCV图像腐蚀OpenCV|OpenCV实战从入门到精通系列三--can 查看详情

docker从入门到实战

...虚拟化、操作系统虚拟化等。1.1关于虚拟化虚拟化技术的分类与定义在不同领域有不同的理解。对于计算机领域,虚拟化技术主要分为两大类:一类基于硬件虚拟化,另一类基于软件虚拟化。硬件虚拟化并不 查看详情

spark从入门到上手实战

...进行批处理、流式处理、SQL交互式处理及机器学习和Graphx图像计算。目前 查看详情

图书推荐:kotlin从入门到进阶实战

图片发自简书App《Kotlin从入门到进阶实战》从Kotlin语言的基础语法讲起,逐步深入到Kotlin进阶实战,并在最后配合项目实战案例,重点介绍了使用Kotlin+SpringBoot进行服务端开发和使用Kotlin进行Android应用程序开发的内容,让读者不... 查看详情

opencv|opencv实战从入门到精通系列一--opencv宏的讲解

...星标 ★,与你不见不散本文来源学习笔记文章目录图像处理计算机视觉OpenCV网页OpenCV可应用的领域OpenCV模块按宏定义顺序介绍图像处理图像处理技术一般包括图像压缩,增强和复原,匹配、描述和识别这3部分。数字... 查看详情

推荐系统从入门到实战——flask框架的使用

Flask框架的使用​​Flask框架的使用​​​​Flask简介​​​​Flask环境配置​​​​安装virtualenv​​​​创建虚拟环境​​​​激活环境​​​​安装包​​​​测试安装​​​​主要内容​​​​路由​​​​route装饰器​​... 查看详情

《自然语言处理实战入门》文本分类----使用textrnn进行文本分类

...种,如LSTM(更常用),GRU。当然我们也可以把RNN运用到文本分类任务中。这里的文本可以一个句子,文档(短文本,若干句子)或篇章(长文本),因此每段文本的长度都不尽相同。在对文本进行分类时,我们一般会指定一个固定的输... 查看详情

javacv入门教程目录(javacv从入门到实战,javacv指南手册,免费javacv教程)

...章单独创建了《JavaCV入门教程》目录,供大家查看。分类目录​JavaCV入门目录:JavaCV入门指南࿱ 查看详情

javacv入门教程目录(javacv从入门到实战,javacv指南手册,免费javacv教程)

...章单独创建了《JavaCV入门教程》目录,供大家查看。分类目录​JavaCV入门目录:JavaCV入门指南࿱ 查看详情

protobuf从入门到实战

简介     从第一次接触Protobuf到实际使用已经有半年多,刚开始可能被它的名字所唬住,其实就它是一种轻便高效的数据格式,平台无关、语言无关、可扩展,可用于通讯协议和数据存储等领域。优点平台无关... 查看详情

docker-9supervisord参考docker从入门到实战

参考docker从入门到实战使用Supervisor来管理进程Docker容器在启动的时候开启单个进程,比如,一个ssh或者apache的daemon服务。但我们经常需要在一个机器上开启多个服务,这可以有很多方法,最简单的就是把多个启动命令放到一个... 查看详情

protobuf从入门到实战(代码片段)

简介     从第一次接触Protobuf到实际使用已经有半年多,刚开始可能被它的名字所唬住,其实就它是一种轻便高效的数据格式,平台无关、语言无关、可扩展,可用于通讯协议和数据存储等领域。 优点平台... 查看详情

《javacv从入门到实战教程合集》介绍和目录

...介绍和目录《JavaCV入门教程》详细介绍了音视频流媒体、图像处理识别等技术的前置知识,JavaCV的基础结构明细以及JavaCV各个结构的说明和用法。通过配合JavaCV实战教程中的实例带领大家全面理解JavaCV。JavaCV入门指南:... 查看详情

convit:使用软卷积归纳偏置改进视觉变换器

...变换器(ViT)依赖于更灵活的自注意力层,并且最近在图像分类方面的表现优于CNN。然而,它们需要对大型外部数据集进行昂贵的预训练或从预训练的卷积网络中提取。在本文中,我们提出以下问题:是否可以结合... 查看详情