小白学习pytorch教程十七基于torch实现unet图像分割模型(代码片段)

刘润森! 刘润森!     2022-12-23     328

关键词:

@Author:Runsen

在图像领域,除了分类,CNN 今天还用于更高级的问题,如图像分割、对象检测等。图像分割是计算机视觉中的一个过程,其中图像被分割成代表图像中每个不同类别的不同段。

上面图片一段代表猫,另一段代表背景。

从自动驾驶汽车到卫星,图像分割在许多领域都很有用。其中最重要的是医学成像。

UNet 是一种卷积神经网络架构,在 CNN 架构几乎没有变化的情况下进行了扩展。它的发明是为了处理生物医学图像,其目标不仅是对是否存在感染进行分类,而且还要识别感染区域。

UNet

论文:https://arxiv.org/abs/1505.04597

UNet结构看起来像一个“U”,该架构由三部分组成:收缩部分、瓶颈部分和扩展部分。收缩段由许多收缩块组成。每个块接受一个输入,应用两个 3X3 卷积层,然后是 2X2 最大池化。每个块之后的内核或特征图的数量加倍,以便架构可以有效地学习复杂的结构。最底层介于收缩层和膨胀层之间。它使用两个 3X3 CNN 层,然后是 2X2 上卷积层。

每个块将输入传递给两个 3X3 CNN 层,然后是一个 2X2 上采样层。同样在每个块之后,卷积层使用的特征图数量减半以保持对称性。然而,每次输入也会附加相应收缩层的特征图。此操作将确保在收缩图像时学习的特征将用于重建它。扩展块的数量与收缩块的数量相同。之后,生成的映射通过另一个 3X3 CNN 层,特征映射的数量等于所需的片段数量。

torch实现

使用的数据集是:https://www.kaggle.com/paultimothymooney/chiu-2015

这个数据集用于分割糖尿病性黄斑水肿的光学相干断层扫描图像的图像。

对于mat的数据,使用scipy.io.loadmat进行加载

下面使用 Pytorch 框架实现了 UNet 模型,代码来源下面的Github:https://github.com/Hsankesara/DeepResearch

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

class UNet(nn.Module):
    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(out_channels),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(out_channels),
                )
        return block
    
    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
            block = torch.nn.Sequential(
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(mid_channel),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(mid_channel),
                    torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
                    )
            return  block
    
    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
            block = torch.nn.Sequential(
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(mid_channel),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(mid_channel),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(out_channels),
                    )
            return  block
    
    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()
        #Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        # Bottleneck
        self.bottleneck = torch.nn.Sequential(
                            torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),
                            torch.nn.ReLU(),
                            torch.nn.BatchNorm2d(512),
                            torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),
                            torch.nn.ReLU(),
                            torch.nn.BatchNorm2d(512),
                            torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
                            )
        # Decode
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)
        
    def crop_and_concat(self, upsampled, bypass, crop=False):
        if crop:
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            bypass = F.pad(bypass, (-c, -c, -c, -c))
        return torch.cat((upsampled, bypass), 1)
    
    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        final_layer = self.final_layer(decode_block1)
        return  final_layer

上面代码中的 UNet 模块代表了 UNet 的整个架构。contraction_block和expansive_block分别用于创建收缩段和膨胀段。该函数crop_and_concat将收缩层的输出与新的扩展层输入相加。

unet = Unet(in_channel=1,out_channel=2)
#out_channel represents number of segments desired
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(unet.parameters(), lr = 0.01, momentum=0.99)
optimizer.zero_grad()       
outputs = unet(inputs)
# permute such that number of desired segments would be on 4th dimension
outputs = outputs.permute(0, 2, 3, 1)
m = outputs.shape[0]
# Resizing the outputs and label to caculate pixel wise softmax loss
outputs = outputs.resize(m*width_out*height_out, 2)
labels = labels.resize(m*width_out*height_out)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

对于该数据集解决标准教程代码:https://www.kaggle.com/hsankesara/unet-image-segmentation

小白学习pytorch教程十七pytorch中数据集torchvision和torchtext(代码片段)

@Author:Runsen对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。之前使用torchDataLoader类直接加载图像并将其转换为张量。现在结合torchvision和torchtext介绍torch中的内置数据集Torchvision中的数据集MNISTMNIST... 查看详情

小白学习pytorch教程九基于pytorch训练第一个rnn模型(代码片段)

@Author:Runsen当阅读一篇课文时,我们可以根据前面的单词来理解每个单词的,而不是从零开始理解每个单词。这可以称为记忆。卷积神经网络模型(CNN)不能实现这种记忆,因此引入了递归神经网络模型(RNN)来解决这一问题... 查看详情

小白学习pytorch教程十基于大型电影评论数据集训练第一个lstm模型(代码片段)

@Author:Runsen文章目录编码建立字典并对评论进行编码对标签进行编码删除异常值填充序列数据加载器RNN模型的实现训练本博客对原始IMDB数据集进行预处理,建立一个简单的深层神经网络模型,对给定数据进行情感分析... 查看详情

pytorch是啥?

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:具有强大的GPU加速的张量计算(如NumPy... 查看详情

小白学习pytorch教程七基于乳腺癌数据集​​构建logistic二分类模型(代码片段)

...、文字分类都属于这一类。在这篇博客中,将学习如何在PyTorch中实现逻辑回归。文章目录1.数据集加载2.预处理3.模型搭建4.训练和优化1.数据集加载在这里,我将使用来自sklearn库的乳腺癌数据集。这是一个简单的二元类分类数据... 查看详情

小白学习pytorch教程十一基于mnist数据集训练第一个生成性对抗网络(代码片段)

@Author:RunsenGAN是使用两个神经网络模型训练的生成模型。一种模型称为生成网络模型,它学习生成新的似是而非的样本。另一个模型被称为判别网络,它学习区分生成的例子和真实的例子。生成性对抗网络2014࿰... 查看详情

小白学习pytorch教程十一基于mnist数据集训练第一个生成性对抗网络(代码片段)

@Author:RunsenGAN是使用两个神经网络模型训练的生成模型。一种模型称为生成网络模型,它学习生成新的似是而非的样本。另一个模型被称为判别网络,它学习区分生成的例子和真实的例子。生成性对抗网络2014࿰... 查看详情

小白学习pytorch教程十四迁移学习:微调resnet实现男人和女人图像分类(代码片段)

@Author:Runsen上次微调了Alexnet,这次微调ResNet实现男人和女人图像分类。ResNet是ResidualNetworks的缩写,是一种经典的神经网络,用作许多计算机视觉任务。ResNet论文参见此处:https://arxiv.org/abs/1512.03385该模型... 查看详情

小白学习pytorch教程十三迁移学习:微调alexnet实现ant和bee图像分类(代码片段)

@Author:Runsen上次微调了VGG19,这次微调Alexnet实现ant和bee图像分类。多年来,CNN许多变体已经发展起来,从而产生了几种CNN架构。其中最常见的是:LeNet-5(1998)AlexNet(2012)ZFNet(2013)GoogleNet/Inception(2014 查看详情

小白学习pytorch教程十二迁移学习:微调vgg19实现图像分类(代码片段)

@Author:Runsen前言:迁移学习就是利用数据、任务或模型之间的相似性,将在旧的领域学习过或训练好的模型,应用于新的领域这样的一个过程。从这段定义里面,我们可以窥见迁移学习的关键点所在,... 查看详情

深度学习-pytorch框架实战系列

深度学习-PyTorch框架实战系列PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能... 查看详情

小白学习pytorch教程八使用图像数据增强手段,提升cifar-10数据集精确度(代码片段)

@Author:Runsen上次基于CIFAR-10数据集,使用PyTorch​​构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段。importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.datai 查看详情

深度学习100例|第41天:语音识别-pytorch实现(代码片段)

...环境配置教程:小白入门深度学习|第四篇:配置PyTorch环境👉往期精彩内容深度学习100例|第1例:猫狗识别-PyTorch实现深度学习100例|第2例:人脸表情识别-PyTorch实现深度学习100例|第3天:交通标志识别-PyTorch... 查看详情

小白学习c++教程十七c++中的字符数组和字符串常见的函数(代码片段)

@Author:Runsen字符数组charmychar[6]='H','e','l','l','o';下面定义的字符串数组在C/C++中的内存表示#include<iostream>usingnamespacestd;intmain()charmycha 查看详情

pytorch60分钟入门教程:pytorch深度学习官方入门中文教程(代码片段)

什么是 PyTorch?PyTorch是一个基于Python的科学计算包,主要定位两类人群:NumPy的替代品,可以利用GPU的性能进行计算。深度学习研究平台拥有足够的灵活性和速度开始学习Tensors(张量)Tensors类似于NumPy的ndarrays,同时 Tensors可... 查看详情

小白学习之pytorch框架-动手学深度学习(begin)

在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比较执着,想学pytorch,好,有个大神来了,把《动... 查看详情

小白入门pytorch|第一篇:什么是pytorch?(代码片段)

什么是PyTorch?这是一个基于Python的科学计算包,主要分入如下2部分:使用GPU的功能代替numpy一个深刻的学习研究平台,提供最大的灵活性和速度开始学习Tensors(张量)Tensors类似于numpy的ndarrays,另外还可以在GPU... 查看详情

小白学习pytorch教程十五bert:通过pytorch来创建一个文本分类的bert模型(代码片段)

@Author:Runsen2018年,谷歌发表了一篇题为《Pre-trainingofdeepbidirectionalTransformersforLanguageUnderstanding》的论文。在本文中,介绍了一种称为BERT(带转换器Transformers的双向编码Encoder器表示)的语言模型,该模型在问答、自然语言推理、... 查看详情