知识蒸馏实战:使用coatnet蒸馏resnet(代码片段)

AI浩 AI浩     2022-12-05     598

关键词:

文章目录

摘要

知识蒸馏(Knowledge Distillation),简称KD,将已经训练好的模型包含的知识(”Knowledge”),蒸馏(“Distill”)提取到另一个模型里面去。Hinton在"Distilling the Knowledge in a Neural Network"首次提出了知识蒸馏(暗知识提取)的概念,通过引入与教师网络(Teacher network:复杂、但预测精度优越)相关的软目标(Soft-target)作为Total loss的一部分,以诱导学生网络(Student network:精简、低复杂度,更适合推理部署)的训练,实现知识迁移(Knowledge transfer)。论文链接:https://arxiv.org/pdf/1503.02531.pdf

蒸馏的过程

知识蒸馏使用的是Teacher—Student模型,其中teacher是“知识”的输出者,student是“知识”的接受者。知识蒸馏的过程分为2个阶段:

  • 原始模型训练: 训练"Teacher模型", 简称为Net-T,它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对"Teacher模型"不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入X, 其都能输出Y,其中Y经过softmax的映射,输出值对应相应类别的概率值。
  • 精简模型训练: 训练"Student模型", 简称为Net-S,它是参数量较小、模型结构相对简单的单模型。同样的,对于输入X,其都能输出Y,Y经过softmax映射后同样能输出对应相应类别的概率值。
  • Teacher学习能力强,可以将它学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力。复杂笨重但是效果好的Teacher模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的Student小模型。

最终结论

先把结论说了吧! Teacher网络使用coatnet_2,Student网络使用ResNet18。如下表

网络epochsACC
coatnet_25092%
ResNet185086%
ResNet18 +KD5089%

在相同的条件下,加入知识蒸馏后,ResNet18的ACC上升了3个点,提升的还是很高的。如下图:

数据准备

数据使用我以前在图像分类任务中的数据集——植物幼苗数据集,先将数据集转为训练集和验证集。执行代码:

import glob
import os
import shutil

image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
    print('true')
    #os.rmdir(file_dir)
    shutil.rmtree(file_dir)#删除再建立
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
    file_class=file.replace("\\\\","/").split('/')[-2]
    file_name=file.replace("\\\\","/").split('/')[-1]
    file_class=os.path.join(train_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class=file.replace("\\\\","/").split('/')[-2]
    file_name=file.replace("\\\\","/").split('/')[-1]
    file_class=os.path.join(val_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

教师网络

教师网络选用coatnet_2,是一个比较大一点的网络了,模型的大小有200M。训练50个epoch,最好的模型在92%左右。

步骤

新建teacher_train.py,插入代码:

导入需要的库

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from model.coatnet import coatnet_2

import json
import os

定义训练和验证函数


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch:  [/ (:.0f%)]\\tLoss: :.6f'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:,loss:'.format(epoch, ave_loss))

Best_ACC=0
# 验证过程
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\\nVal set: Average loss: :.4f, Accuracy: / (:.0f%)\\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc

定义全局参数

if __name__ == '__main__':
    # 创建保存模型的文件夹
    file_dir = 'CoatNet'
    if os.path.exists(file_dir):
        print('true')

        os.makedirs(file_dir, exist_ok=True)
    else:
        os.makedirs(file_dir)

    # 设置全局参数
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 50
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

图像预处理与增强

 # 数据预处理7
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])

    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])


读取数据

使用pytorch默认读取数据的方式。

    # 读取数据
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # 导入数据
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型和Loss

   # 实例化模型并且移动到GPU
    criterion = nn.CrossEntropyLoss()

    model_ft = coatnet_2()
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 12)
    model_ft.to(DEVICE)
    # 选择简单暴力的Adam优化器,学习率调低
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # 训练
    val_acc_list= 
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, DEVICE, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc=val(model_ft, DEVICE, test_loader)
        val_acc_list[epoch]=acc
        with open('result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'CoatNet/model_final.pth')

完成上面的代码就可以开始训练Teacher网络了。

学生网络

学生网络选用ResNet18,是一个比较小一点的网络了,模型的大小有40M。训练50个epoch,最好的模型在86%左右。

步骤

新建student_train.py,插入代码:

导入需要的库

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from torchvision.models.resnet import resnet18

import json
import os

定义训练和验证函数


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch:  [/ (:.0f%)]\\tLoss: :.6f'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:,loss:'.format(epoch, ave_loss))

Best_ACC=0
# 验证过程
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\\nVal set: Average loss: :.4f, Accuracy: / (:.0f%)\\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc

定义全局参数

if __name__ == '__main__':
    # 创建保存模型的文件夹
    file_dir = 'resnet'
    if os.path.exists(file_dir):
        print('true')

        os.makedirs(file_dir, exist_ok=True)
    else:
        os.makedirs(file_dir)

    # 设置全局参数
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 50
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

图像预处理与增强

 # 数据预处理7
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])

    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])


读取数据

使用pytorch默认读取数据的方式。

    # 读取数据
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str知识蒸馏nst算法实战:使用coatnet蒸馏resnet18(代码片段)

文章目录摘要最终结论模型ResNet18,ResNet34CoatNet数据准备训练Teacher模型步骤导入需要的库定义训练和验证函数定义全局参数图像预处理与增强读取数据设置模型和Loss学生网络步骤导入需要的库定义训练和验证函数定义全局参... 查看详情

知识蒸馏nst算法实战:使用coatnet蒸馏resnet18(代码片段)

文章目录摘要最终结论模型ResNet18,ResNet34CoatNet数据准备训练Teacher模型步骤导入需要的库定义训练和验证函数定义全局参数图像预处理与增强读取数据设置模型和Loss学生网络步骤导入需要的库定义训练和验证函数定义全局参... 查看详情

知识蒸馏irg算法实战:使用resnet50蒸馏resnet18(代码片段)

...用。模型压缩方法帮助模型在效率和精度之间进行折中。知识蒸馏是模型压缩的一种有效手段,它的核心思想是迫使轻量级的学生模型去学习教师模型提取到的知识,从而提高学生模型的性能。已有的知识蒸馏方法可以... 查看详情

知识蒸馏irg算法实战:使用resnet50蒸馏resnet18(代码片段)

...用。模型压缩方法帮助模型在效率和精度之间进行折中。知识蒸馏是模型压缩的一种有效手段,它的核心思想是迫使轻量级的学生模型去学习教师模型提取到的知识,从而提高学生模型的性能。已有的知识蒸馏方法可以... 查看详情

知识蒸馏deit算法实战:使用regnet蒸馏deit模型(代码片段)

文章目录摘要最终结论项目结构模型和lossmodel.py代码losses.py代码训练Teacher模型步骤导入需要的库定义训练和验证函数定义全局参数图像预处理与增强读取数据设置模型和Loss学生网络步骤导入需要的库定义训练和验证函数定义全... 查看详情

yolov5/v7进阶实战|目录|安卓|pyqt5|剪枝✂️|蒸馏⚗️|flaskweb|改进教程

...LOv5剪枝|模型剪枝理论篇YOLOv5剪枝💖|模型剪枝实战篇知识蒸馏|知识蒸馏理论篇知识蒸馏🌟|YOLOv5知识蒸馏实战篇知识蒸馏🌟|YOLOv7知识蒸馏实战篇YOLOv5安卓部署📱|理论+环境配置+实战PyQt5|PyQt5环境配置及组... 查看详情

知识蒸馏是不是具有整体效应?

】知识蒸馏是不是具有整体效应?【英文标题】:Doesknowledgedistillationhaveanensembleeffect?知识蒸馏是否具有整体效应?【发布时间】:2021-09-2313:22:21【问题描述】:我对知识蒸馏知之甚少。我有一个问题。有一个模型表现出99%的性... 查看详情

知识蒸馏算法汇总(代码片段)

知识蒸馏有两大类:一类是logits蒸馏,另一类是特征蒸馏。logits蒸馏指的是在softmax时使用较高的温度系数,提升负标签的信息,然后使用Student和Teacher在高温softmax下logits的KL散度作为loss。中间特征蒸馏就是强迫Stu... 查看详情

知识蒸馏之自蒸馏

知识蒸馏之自蒸馏@TOC知识蒸馏之自蒸馏本文整理了近几年顶会中的蒸馏类文章(强调self-distillation),后续可能会继续更新其他计算机视觉领域顶会中的相关工作,欢迎各位伙伴相互探讨。背景知识-注意力蒸馏、自... 查看详情

一文搞懂知识蒸馏knowledgedistillation算法原理(代码片段)

知识蒸馏算法原理精讲文章目录知识蒸馏算法原理精讲1.什么是知识蒸馏?2.轻量化网络的方式有哪些?3.为什么要进行知识蒸馏?3.1提升模型精度3.2降低模型时延,压缩网络参数3.3标签之间的域迁移4.知识蒸馏的... 查看详情

目标检测yolov5遇上知识蒸馏(代码片段)

...pruning)稀疏表示(Sparserepresentation)模型量化(Modelquantification)知识蒸馏(Konwledgedistillation)本文主要来研究知识蒸馏的相关知识,并尝试用知识蒸馏的方法对YOLOv5进行改进。知识蒸馏理论简介概述知识蒸馏(KnowledgeDistillatio 查看详情

模型加速知识蒸馏

 现状知识蒸馏核心思想细节补充   知识蒸馏的思想最早是由Hinton大神在15年提出的一个黑科技,Hinton在一些报告中将该技术称之为DarkKnowledge,技术上一般叫做知识蒸馏(KnowledgeDistillation),是模型加速中的一种重要的... 查看详情

知识蒸馏:distillingtheknowledgeinaneuralnetwork

文章目录摘要1简介2蒸馏2.1匹配逻辑是蒸馏的特殊情况3MNIST的初步实验4语音识别实验4.1结果5在大的数据集上训练专家模型5.1JFT数据集5.2专家模型5.3为专家分配类5.4使用专家集合进行推理5.5结果6软目标作为正则化器6.1使用软目标... 查看详情

神经网络分类知识蒸馏

参考链接:知识蒸馏是什么?一份入门随笔https://zhuanlan.zhihu.com/p/90049906[论文阅读]知识蒸馏(DistillingtheKnowledgeinaNeuralNetwork)https://blog.csdn.net/ZY_miao/article/details/110182948DistillingtheKnowledgeinaNeuralNetwork[ 查看详情

知识蒸馏基本原理

...蒸馏的目的是在不同的温度下提出了特定的成分。说回到知识蒸馏ÿ 查看详情

知识蒸馏基本原理

...蒸馏的目的是在不同的温度下提出了特定的成分。说回到知识蒸馏ÿ 查看详情

知识蒸馏基本原理

...蒸馏的目的是在不同的温度下提出了特定的成分。说回到知识蒸馏ÿ 查看详情

Tensorflow 2 + Keras 的知识蒸馏损失

】Tensorflow2+Keras的知识蒸馏损失【英文标题】:KnowledgeDistillationlosswithTensorflow2+Keras【发布时间】:2020-03-2700:08:54【问题描述】:我正在尝试实现一个非常简单的keras模型,该模型使用来自另一个模型的知识蒸馏[1]。粗略地说,我... 查看详情