小白学习pytorch教程十六在多标签分类任务上微调bert模型(代码片段)

刘润森! 刘润森!     2023-03-09     796

关键词:

@Author:Runsen

BERT模型在NLP各项任务中大杀四方,那么我们如何使用这一利器来为我们日常的NLP任务来服务呢?首先介绍使用BERT做文本多标签分类任务。

文本多标签分类是常见的NLP任务,文本介绍了如何使用Bert模型完成文本多标签分类,并给出了各自的步骤。

参考官方教程:https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html

复旦大学邱锡鹏老师课题组的研究论文《How to Fine-Tune BERT for Text Classification?》。

论文: https://arxiv.org/pdf/1905.05583.pdf

这篇论文的主要目的在于在文本分类任务上探索不同的BERT微调方法并提供一种通用的BERT微调解决方法。这篇论文从三种路线进行了探索:

  • (1) BERT自身的微调策略,包括长文本处理、学习率、不同层的选择等方法;
  • (2) 目标任务内、领域内及跨领域的进一步预训练BERT;
  • (3) 多任务学习。微调后的BERT在七个英文数据集及搜狗中文数据集上取得了当前最优的结果。

作者的实现代码: https://github.com/xuyige/BERT4doc-Classification

数据集来源:https://www.kaggle.com/shivanandmn/multilabel-classification-dataset?select=train.csv

该数据集包含 6 个不同的标签(计算机科学、物理、数学、统计学、生物学、金融),以根据摘要和标题对研究论文进行分类。
标签列中的值 1 表示标签属于该标签。每个论文有多个标签为 1。

Bert模型加载

Transformer 为我们提供了一个基于 Transformer 的可以微调的预训练网络。

由于数据集是英文, 因此这里选择加载bert-base-uncased。

具体下载链接:https://huggingface.co/bert-base-uncased/tree/main

from transformers import BertTokenizerFast as BertTokenizer
# 直接下载很很慢,建议下载到文件夹中
# BERT_MODEL_NAME = "bert-base-uncased"
BERT_MODEL_NAME = "model/bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

微调BERT模型

bert微调就是在预训练模型bert的基础上只需更新后面几层的参数,这相对于从头开始训练可以节省大量时间,甚至可以提高性能,通常情况下在模型的训练过程中,我们也会更新bert的参数,这样模型的性能会更好。

微调BERT模型主要在D_out进行相关的改变,去除segment层,直接采用了字符输入,不再需要segment层。

下面是微调BERT的主要代码

class BertClassifier(nn.Module):
    def __init__(self, num_labels: int, BERT_MODEL_NAME, freeze_bert=False):
        super().__init__()
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)

        #  hidden size of BERT, hidden size of our classifier, and number of labels to classify
        D_in, H, D_out = self.bert.config.hidden_size, 50, num_labels

        # Instantiate an one-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(D_in, H),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(H, D_out),
        )
        # loss
        self.loss_func = nn.BCEWithLogitsLoss()

        if freeze_bert:
            print("freezing bert parameters")
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)

        last_hidden_state_cls = outputs[0][:, 0, :]

        logits = self.classifier(last_hidden_state_cls)

        if labels is not None:
            predictions = torch.sigmoid(logits)
            loss = self.loss_func(
                predictions.view(-1, self.num_labels), labels.view(-1, self.num_labels)
            )
            return loss
        else:
            return logits

其他

关于数据预处理,DataLoader等代码有点多,这里不一一列举,需要代码的在公众号回复:”bert“ 。

最后的训练结果如下所示:

小白学习pytorch教程十二迁移学习:微调vgg19实现图像分类(代码片段)

@Author:Runsen前言:迁移学习就是利用数据、任务或模型之间的相似性,将在旧的领域学习过或训练好的模型,应用于新的领域这样的一个过程。从这段定义里面,我们可以窥见迁移学习的关键点所在,... 查看详情

小白学习pytorch教程十五bert:通过pytorch来创建一个文本分类的bert模型(代码片段)

@Author:Runsen2018年,谷歌发表了一篇题为《Pre-trainingofdeepbidirectionalTransformersforLanguageUnderstanding》的论文。在本文中,介绍了一种称为BERT(带转换器Transformers的双向编码Encoder器表示)的语言模型,该模型在问答、自然语言推理、... 查看详情

小白学习pytorch教程十四迁移学习:微调resnet实现男人和女人图像分类(代码片段)

@Author:Runsen上次微调了Alexnet,这次微调ResNet实现男人和女人图像分类。ResNet是ResidualNetworks的缩写,是一种经典的神经网络,用作许多计算机视觉任务。ResNet论文参见此处:https://arxiv.org/abs/1512.03385该模型... 查看详情

小白学习pytorch教程十三迁移学习:微调alexnet实现ant和bee图像分类(代码片段)

@Author:Runsen上次微调了VGG19,这次微调Alexnet实现ant和bee图像分类。多年来,CNN许多变体已经发展起来,从而产生了几种CNN架构。其中最常见的是:LeNet-5(1998)AlexNet(2012)ZFNet(2013)GoogleNet/Inception(2014 查看详情

小白学习pytorch教程七基于乳腺癌数据集​​构建logistic二分类模型(代码片段)

...、文字分类都属于这一类。在这篇博客中,将学习如何在PyTorch中实现逻辑回归。文章目录1.数据集加载2.预处理3.模型搭建4.训练和优化1.数据集加载在这里,我将使用来自sklearn库的乳腺癌数据集。这是一个简单的二元类分类数据... 查看详情

pytorch100例|用深度学习处理分类问题实战教程(代码片段)

PyTorch和TensorFlow库是用于深度学习的两个最常用的Python库。PyTorch是Facebook开发的,而TensorFlow是Google的项目。在本文中,你将看到如何使用PyTorch库来解决分类问题。分类问题属于机器学习问题的范畴,其中给定一组特... 查看详情

小白学习pytorch教程八使用图像数据增强手段,提升cifar-10数据集精确度(代码片段)

@Author:Runsen上次基于CIFAR-10数据集,使用PyTorch​​构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段。importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.datai 查看详情

在多标签图像分类任务中,哪个损失函数会收敛得很好?

】在多标签图像分类任务中,哪个损失函数会收敛得很好?【英文标题】:Whichlossfunctionwillconvergewellinmulti-labelimageclassificationtask?【发布时间】:2020-05-2518:52:23【问题描述】:我通过使用sigmoid作为输出激活函数和binary_crossentropy作... 查看详情

小白学习pytorch教程十基于大型电影评论数据集训练第一个lstm模型(代码片段)

@Author:Runsen文章目录编码建立字典并对评论进行编码对标签进行编码删除异常值填充序列数据加载器RNN模型的实现训练本博客对原始IMDB数据集进行预处理,建立一个简单的深层神经网络模型,对给定数据进行情感分析... 查看详情

小白学习pytorch教程十七基于torch实现unet图像分割模型(代码片段)

@Author:Runsen在图像领域,除了分类,CNN今天还用于更高级的问题,如图像分割、对象检测等。图像分割是计算机视觉中的一个过程,其中图像被分割成代表图像中每个不同类别的不同段。上面图片一段代表... 查看详情

小白学习pytorch教程十七pytorch中数据集torchvision和torchtext(代码片段)

@Author:Runsen对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。之前使用torchDataLoader类直接加载图像并将其转换为张量。现在结合torchvision和torchtext介绍torch中的内置数据集Torchvision中的数据集MNISTMNIST... 查看详情

小白学习pytorch教程九基于pytorch训练第一个rnn模型(代码片段)

@Author:Runsen当阅读一篇课文时,我们可以根据前面的单词来理解每个单词的,而不是从零开始理解每个单词。这可以称为记忆。卷积神经网络模型(CNN)不能实现这种记忆,因此引入了递归神经网络模型(RNN)来解决这一问题... 查看详情

pytorch100例|用深度学习处理分类问题实战教程(代码片段)

PyTorch和TensorFlow库是用于深度学习的两个最常用的Python库。PyTorch是Facebook开发的,而TensorFlow是Google的项目。在本文中,你将看到如何使用PyTorch库来解决分类问题。分类问题属于机器学习问题的范畴,其中给定一组特... 查看详情

pytorch100例|用深度学习处理分类问题实战教程(代码片段)

PyTorch和TensorFlow库是用于深度学习的两个最常用的Python库。PyTorch是Facebook开发的,而TensorFlow是Google的项目。在本文中,你将看到如何使用PyTorch库来解决分类问题。分类问题属于机器学习问题的范畴,其中给定一组特... 查看详情

小白学习keras教程二基于cifar-10数据集训练简单的mlp分类模型(代码片段)

@Author:Runsen分类任务的MLP当目标(y)是离散的(分类的)对于损失函数,使用交叉熵;对于评估指标,通常使用accuracy数据集描述CIFAR-10数据集包含10个类中的60000个图像—50000个用于培训,10000个用于测试有关更多信息,请参阅... 查看详情

机器学习模型性能度量详解python机器学习系列(十六)(代码片段)

机器学习模型性能度量【Python机器学习系列(十六)】文章目录1.错误率与精度2.查准率查全率与F12.1TP、FP、FN、TN2.2查准率与查全率2.3P-R图与平衡点2.4F12.5宏查准率&宏查全率&宏F12.6微查准率&微查全率&微F13.ROC... 查看详情

机器学习模型性能度量详解python机器学习系列(十六)(代码片段)

机器学习模型性能度量【Python机器学习系列(十六)】文章目录1.错误率与精度2.查准率查全率与F12.1TP、FP、FN、TN2.2查准率与查全率2.3P-R图与平衡点2.4F12.5宏查准率&宏查全率&宏F12.6微查准率&微查全率&微F13.ROC... 查看详情

小白学习pytorch教程十七pytorch中数据集torchvision和torchtext(代码片段)

@Author:Runsen对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。之前使用torchDataLoader类直接加载图像并将其转换为张量。现在结合torchvision和torchtext介绍torch中的内置数据集Torchvision中的数据集MNISTMNIST... 查看详情