GAN in Practice: Generating Building Images with pix2pixGAN in PyTorch (Code Snippets)

码农男孩     2023-03-09


For the underlying theory of pix2pixGAN, see:

GAN Notes: pix2pixGAN Network Principles and Paper Walkthrough

1. Loading the Dataset

Dataset download: pix2pixGAN training dataset (building facades)

After unzipping there are two folders, base and extended; we use base as the training set and extended as the test set.

imgs_path = sorted(glob.glob('dataset/base/*.jpg'))    # sort so each .jpg pairs with its .png
annos_path = sorted(glob.glob('dataset/base/*.png'))

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize((256, 256)),
                                transforms.Normalize(0.5,0.5)])

class CMP_dataset(data.Dataset):
    def __init__(self, imgs_path, annos_path):
        self.imgs_path = imgs_path
        self.annos_path = annos_path
    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        anno_path = self.annos_path[index]
        pil_img = Image.open(img_path)
        pil_img = transform(pil_img)
        anno_img = Image.open(anno_path)
        anno_img = anno_img.convert('RGB')
        pil_anno = transform(anno_img)
        return pil_anno, pil_img
    def __len__(self):
        return len(self.imgs_path)

dataset = CMP_dataset(imgs_path, annos_path)
dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)
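
As a quick sanity check, the sketch below (my own addition, assuming the imports shown in the complete code of Section 5) pulls one batch from the dataloader and displays an annotation/photo pair; the *0.5 + 0.5 undoes the Normalize(0.5, 0.5) transform.

# Sanity-check sketch (not from the original post): inspect one batch
annos_batch, imgs_batch = next(iter(dataloader))
print(annos_batch.shape, imgs_batch.shape)  # expected: torch.Size([32, 3, 256, 256]) for both

plt.subplot(1, 2, 1)
plt.imshow((annos_batch[0].permute(1, 2, 0) * 0.5 + 0.5).numpy())  # annotation map
plt.subplot(1, 2, 2)
plt.imshow((imgs_batch[0].permute(1, 2, 0) * 0.5 + 0.5).numpy())   # real photo
plt.show()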

2. Defining the Downsampling and Upsampling Modules

class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels,
                                      kernel_size=3,
                                      stride=2,
                                      padding=1),
                            nn.LeakyReLU(inplace=True),
            )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        x = self.conv_relu(x)
        if is_bn:
            x = self.bn(x)
        return x
# Upsampling module
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        x = self.upconv_relu(x)
        x = self.bn(x)
        if is_drop:
            x = F.dropout2d(x)
        return x
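
As a rough illustration (my own sketch, not part of the original post), the stride-2 convolution in Downsample halves the spatial resolution, while the transposed convolution in Upsample doubles it again:

# Shape-check sketch with a dummy tensor (assumes torch is imported as in Section 5)
x = torch.randn(1, 64, 128, 128)
down = Downsample(64, 128)
up = Upsample(128, 64)
print(down(x).shape)      # torch.Size([1, 128, 64, 64])   -- H and W halved
print(up(down(x)).shape)  # torch.Size([1, 64, 128, 128])  -- back to the input size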

3. Defining the Generator

We define six downsampling layers and five upsampling layers plus a final output layer, wired as a U-Net style encoder and decoder with skip connections.

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 512)
        self.down6 = Downsample(512, 512)

        self.up1 = Upsample(512, 512)
        self.up2 = Upsample(1024, 512)
        self.up3 = Upsample(1024, 256)
        self.up4 = Upsample(512, 128)
        self.up5 = Upsample(256, 64)

        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)

    def forward(self, x):
        x1 = self.down1(x, is_bn=False)  # torch.Size([8, 64, 128, 128])
        x2 = self.down2(x1)  # torch.Size([8, 128, 64, 64])
        x3 = self.down3(x2)  # torch.Size([8, 256, 32, 32])
        x4 = self.down4(x3)  # torch.Size([8, 512, 16, 16])
        x5 = self.down5(x4)  # torch.Size([8, 512, 8, 8])
        x6 = self.down6(x5)  # torch.Size([8, 512, 4, 4])

        x6 = self.up1(x6, is_drop=True)  # torch.Size([8, 512, 8, 8])
        x6 = torch.cat([x5, x6], dim=1)  # torch.Size([8, 1024, 8, 8])

        x6 = self.up2(x6, is_drop=True)  # torch.Size([8, 512, 16, 16])
        x6 = torch.cat([x4, x6], dim=1)  # torch.Size([8, 1024, 16, 16])

        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x3, x6], dim=1)

        x6 = self.up4(x6)
        x6 = torch.cat([x2, x6], dim=1)

        x6 = self.up5(x6)
        x6 = torch.cat([x1, x6], dim=1)

        x6 = torch.tanh(self.last(x6))
        return x6
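
A minimal sketch (my addition) to confirm that the generator maps a 256x256 annotation tensor to a 256x256 image tensor, with values inside (-1, 1) because of the final tanh:

# Generator shape check on a dummy batch (hypothetical batch size of 4)
gen_test = Generator()
dummy = torch.randn(4, 3, 256, 256)
out = gen_test(dummy)
print(out.shape)                            # torch.Size([4, 3, 256, 256])
print(out.min().item(), out.max().item())   # both inside (-1, 1) due to tanh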

4. Defining the Discriminator

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.down1 = Downsample(6, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.conv = nn.Conv2d(256, 512, 3, 1, 1)
        self.bn = nn.BatchNorm2d(512)
        self.last = nn.Conv2d(512, 1, 3, 1)

    def forward(self, anno, img):
        x = torch.cat([anno, img], dim=1)  # batch*6*H*W
        x = self.down1(x, is_bn=False)
        x = self.down2(x)
        x = F.dropout2d(self.down3(x))
        x = F.dropout2d(F.leaky_relu(self.conv(x)))
        x = F.dropout2d(self.bn(x))
        x = torch.sigmoid(self.last(x))
        return x
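
Because the last convolution has no padding, the discriminator does not return a single scalar but a 30x30 grid of per-patch probabilities (a PatchGAN-style output). A small sketch, again my own addition, makes the output size visible:

# Discriminator shape check: 6-channel concatenated input -> patch probability grid
dis_test = Discriminator()
anno = torch.randn(4, 3, 256, 256)
img = torch.randn(4, 3, 256, 256)
print(dis_test(anno, img).shape)  # torch.Size([4, 1, 30, 30]) -- one score per patch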

5. Complete Code

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image

imgs_path = sorted(glob.glob('dataset/base/*.jpg'))    # sort so each .jpg pairs with its .png
annos_path = sorted(glob.glob('dataset/base/*.png'))

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize((256, 256)),
                                transforms.Normalize(0.5,0.5)])

class CMP_dataset(data.Dataset):
    def __init__(self, imgs_path, annos_path):
        self.imgs_path = imgs_path
        self.annos_path = annos_path
    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        anno_path = self.annos_path[index]
        pil_img = Image.open(img_path)
        pil_img = transform(pil_img)
        anno_img = Image.open(anno_path)
        anno_img = anno_img.convert('RGB')
        pil_anno = transform(anno_img)
        return pil_anno, pil_img
    def __len__(self):
        return len(self.imgs_path)

dataset = CMP_dataset(imgs_path, annos_path)
dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)


# Build the models
# Downsampling module
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels,
                                      kernel_size=3,
                                      stride=2,
                                      padding=1),
                            nn.LeakyReLU(inplace=True),
            )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        x = self.conv_relu(x)
        if is_bn:
            x = self.bn(x)
        return x

# Upsampling module
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        x = self.upconv_relu(x)
        x = self.bn(x)
        if is_drop:
            x = F.dropout2d(x)
        return x

# Generator: six downsampling layers, five upsampling layers, plus one output layer
# U-Net style encoder-decoder with skip connections
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 512)
        self.down6 = Downsample(512, 512)

        self.up1 = Upsample(512, 512)
        self.up2 = Upsample(1024, 512)
        self.up3 = Upsample(1024, 256)
        self.up4 = Upsample(512, 128)
        self.up5 = Upsample(256, 64)

        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)

    def forward(self, x):
        x1 = self.down1(x, is_bn=False)  # torch.Size([8, 64, 128, 128])
        x2 = self.down2(x1)  # torch.Size([8, 128, 64, 64])
        x3 = self.down3(x2)  # torch.Size([8, 256, 32, 32])
        x4 = self.down4(x3)  # torch.Size([8, 512, 16, 16])
        x5 = self.down5(x4)  # torch.Size([8, 512, 8, 8])
        x6 = self.down6(x5)  # torch.Size([8, 512, 4, 4])

        x6 = self.up1(x6, is_drop=True)  # torch.Size([8, 512, 8, 8])
        x6 = torch.cat([x5, x6], dim=1)  # torch.Size([8, 1024, 8, 8])

        x6 = self.up2(x6, is_drop=True)  # torch.Size([8, 512, 16, 16])
        x6 = torch.cat([x4, x6], dim=1)  # torch.Size([8, 1024, 16, 16])

        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x3, x6], dim=1)

        x6 = self.up4(x6)
        x6 = torch.cat([x2, x6], dim=1)

        x6 = self.up5(x6)
        x6 = torch.cat([x1, x6], dim=1)

        x6 = torch.tanh(self.last(x6))
        return x6

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.down1 = Downsample(6, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.conv = nn.Conv2d(256, 512, 3, 1, 1)
        self.bn = nn.BatchNorm2d(512)
        self.last = nn.Conv2d(512, 1, 3, 1)

    def forward(self, anno, img):
        x = torch.cat([anno, img], dim=1)  # batch*6*H*W
        x = self.down1(x, is_bn=False)
        x = self.down2(x)
        x = F.dropout2d(self.down3(x))
        x = F.dropout2d(F.leaky_relu(self.conv(x)))
        x = F.dropout2d(self.bn(x))
        x = torch.sigmoid(self.last(x))
        return x

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)

# Loss function
# binary cross-entropy for the cGAN adversarial term
loss_fn = torch.nn.BCELoss()

d_optimizer = torch.optim.Adam(dis.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Plotting helper: show the input annotation, ground truth, and prediction side by side
def generate_images(model, test_input, true_target):
    prediction = model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy()
    test_input = test_input.permute(0, 2, 3, 1).cpu().numpy()
    true_target = true_target.permute(0, 2, 3, 1).cpu().numpy()
    plt.figure(figsize=(5, 5))
    display_list = [test_input[0], true_target[0], prediction[0]]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']
    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

test_imgs_path = sorted(glob.glob('dataset/extended/*.jpg'))
test_annos_path = sorted(glob.glob('dataset/extended/*.png'))

test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
test_dataloader = data.DataLoader(test_dataset, batch_size=32)

LAMBDA = 10  # weight of the L1 loss term

annos_batch, imgs_batch = next(iter(test_dataloader))
imgs_batch = imgs_batch.to(device)
annos_batch = annos_batch.to(device)

D_loss = []  # discriminator loss per epoch
G_loss = []  # generator loss per epoch

# Training loop
for epoch in range(100):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)
    for step, (annos, imgs) in enumerate(dataloader):
        imgs = imgs.to(device)
        annos = annos.to(device)

        d_optimizer.zero_grad()
        disc_real_output = dis(annos, imgs)  # discriminator on a real (annotation, image) pair
        d_real_loss = loss_fn(disc_real_output,
                              torch.ones_like(disc_real_output, device=device))
        d_real_loss.backward()

        # the generator takes the annotation image as input and produces a fake photo
        # (pix2pix uses no random noise vector)
        gen_output = gen(annos)
        disc_gen_output = dis(annos, gen_output.detach())
        d_fake_loss = loss_fn(disc_gen_output,
                              torch.zeros_like(disc_gen_output, device=device))
        d_fake_loss.backward()

        disc_loss = d_real_loss + d_fake_loss  # total discriminator loss
        d_optimizer.step()

        g_optimizer.zero_grad()

        disc_gen_output = dis(annos, gen_output)  # discriminator on the fake pair, without detach, for the generator update
        gen_loss_crossentropy = loss_fn(disc_gen_output,
                                        torch.ones_like(disc_gen_output, device=device))
        gen_l1_loss = torch.mean(torch.abs(imgs - gen_output))  # pixel-wise L1 loss
        gen_loss = gen_loss_crossentropy + (LAMBDA * gen_l1_loss)
        gen_loss.backward()
        g_optimizer.step()

        with torch.no_grad():
            D_epoch_loss += disc_loss.item()
            G_epoch_loss += gen_loss.item()
    with torch.no_grad():
        D_epoch_loss /= count
        G_epoch_loss /= count
        D_loss.append(D_epoch_loss)
        G_loss.append(G_epoch_loss)
        # after each epoch, print progress and plot generated samples on the fixed test batch
        print("Epoch:", epoch)
        generate_images(gen, annos_batch, imgs_batch)
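
The loop records D_loss and G_loss but never plots them; the sketch below (my own addition, file names are hypothetical) draws the loss curves and saves the trained weights once training finishes:

# Plot the recorded per-epoch losses and save the trained models (hypothetical file names)
plt.plot(D_loss, label='discriminator loss')
plt.plot(G_loss, label='generator loss')
plt.xlabel('epoch')
plt.legend()
plt.show()

torch.save(gen.state_dict(), 'pix2pix_gen.pth')
torch.save(dis.state_dict(), 'pix2pix_dis.pth')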

6. Results

Generated results at Epoch 0, Epoch 50, and Epoch 100, respectively.
