关键词:
pix2pixGAN的相关原理说明参考:
GAN笔记-- pix2pixGAN 网络原理介绍以及论文解读
一、数据集加载
数据集下载:pix2pixGAN训练数据集,建筑物数据集
解压后得到两个文件夹 base 和 extended,我们将 base 作为训练集,extended 作为测试集进行测试
# Training pairs: photos (*.jpg) and label maps (*.png) that share a file stem.
# glob.glob returns files in an arbitrary, filesystem-dependent order, so both
# lists are sorted to make imgs_path[i] and annos_path[i] the same sample.
imgs_path = sorted(glob.glob('dataset/base/*.jpg'))
annos_path = sorted(glob.glob('dataset/base/*.png'))
# ToTensor -> float CHW in [0, 1]; Resize -> 256x256; Normalize(0.5, 0.5) maps
# to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize((256, 256)),
                                transforms.Normalize(0.5, 0.5)])
class CMP_dataset(data.Dataset):
    """Paired (label map, photo) dataset for the CMP facade images.

    Args:
        imgs_path: paths of the photos (used as the generator's targets).
        annos_path: paths of the label maps (used as the generator's inputs);
            element i must belong to the same sample as ``imgs_path[i]``.

    Raises:
        ValueError: if the two path lists have different lengths.

    Each item is ``(anno, img)``, both passed through the module-level
    ``transform``.
    """

    def __init__(self, imgs_path, annos_path):
        if len(imgs_path) != len(annos_path):
            raise ValueError(
                f'got {len(imgs_path)} images but {len(annos_path)} annotations')
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        anno_path = self.annos_path[index]
        # Open inside `with` so file handles are released immediately;
        # convert('RGB') loads the pixels and guarantees 3 channels, which the
        # 6-channel discriminator input (3 + 3) depends on.
        with Image.open(img_path) as raw_img:
            pil_img = raw_img.convert('RGB')
        with Image.open(anno_path) as raw_anno:
            anno_img = raw_anno.convert('RGB')
        return transform(anno_img), transform(pil_img)

    def __len__(self):
        # Read this instance's list, not the module-level `imgs_path` of the
        # same name (otherwise every dataset reports the training-set length).
        return len(self.imgs_path)
# Shuffled mini-batches of 32 (annotation, photo) pairs for training.
dataset = CMP_dataset(imgs_path, annos_path)
dataloader = data.DataLoader(dataset=dataset,
                             batch_size=32,
                             shuffle=True)
二、定义下采样模块和上采样模块
class Downsample(nn.Module):
    """Stride-2 encoder step: 3x3 conv -> LeakyReLU -> optional BatchNorm.

    Maps ``in_channels`` to ``out_channels`` and shrinks the spatial size
    (H, W) to (ceil(H/2), ceil(W/2)).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are what the state_dict keys are built from.
        conv = nn.Conv2d(in_channels, out_channels,
                         kernel_size=3, stride=2, padding=1)
        self.conv_relu = nn.Sequential(conv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        """Apply conv + LeakyReLU, then BatchNorm unless ``is_bn`` is False."""
        features = self.conv_relu(x)
        return self.bn(features) if is_bn else features
# 定义上采样模块
class Upsample(nn.Module):
    """Stride-2 decoder step: 3x3 transposed conv -> LeakyReLU -> BatchNorm -> optional dropout.

    Maps ``in_channels`` to ``out_channels`` and doubles the spatial size
    (H, W) to (2H, 2W).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        deconv = nn.ConvTranspose2d(in_channels, out_channels,
                                    kernel_size=3, stride=2,
                                    padding=1, output_padding=1)
        self.upconv_relu = nn.Sequential(deconv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        """Upsample and normalise; apply channel dropout when ``is_drop`` is True."""
        out = self.bn(self.upconv_relu(x))
        if not is_drop:
            return out
        # F.dropout2d defaults to training=True, so this dropout is active in
        # both train() and eval() mode.
        return F.dropout2d(out)
三、定义生成器
我们将定义六层下采样,五层上采样再外加一个输出层,基于UNet结构进行编码
class Generator(nn.Module):
    """U-Net generator: six stride-2 encoder steps, five decoder steps, one output layer.

    Each decoder output is concatenated channel-wise with the encoder map of
    the same resolution (skip connection) before the next decoder step; the
    first three decoder steps use dropout. Input height/width should be a
    multiple of 64 (2**6) so the skip connections line up. The output has 3
    channels, the input's resolution, and values in [-1, 1] (tanh).
    """

    def __init__(self):
        super().__init__()
        # Encoder: 3 -> 64 -> 128 -> 256 -> 512 -> 512 -> 512 channels.
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 512)
        self.down6 = Downsample(512, 512)
        # Decoder: from up2 on, in_channels are doubled by the skip concat.
        self.up1 = Upsample(512, 512)
        self.up2 = Upsample(1024, 512)
        self.up3 = Upsample(1024, 256)
        self.up4 = Upsample(512, 128)
        self.up5 = Upsample(256, 64)
        # Final 2x upsampling to full resolution; 128 = 64 (down1) + 64 (up5).
        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3, stride=2,
                                       padding=1, output_padding=1)

    def forward(self, x):
        # Encoder: keep every intermediate map except the bottleneck for skips.
        skips = [self.down1(x, is_bn=False)]  # no BatchNorm on the first layer
        for down in (self.down2, self.down3, self.down4, self.down5):
            skips.append(down(skips[-1]))
        h = self.down6(skips[-1])  # bottleneck at 1/64 resolution

        # Decoder: (module, use dropout), paired with skips deepest-first.
        decoder = ((self.up1, True), (self.up2, True), (self.up3, True),
                   (self.up4, False), (self.up5, False))
        for (up, drop), skip in zip(decoder, reversed(skips)):
            h = torch.cat([skip, up(h, is_drop=drop)], dim=1)
        return torch.tanh(self.last(h))
四、定义判别器
# 判别器
class Discriminator(nn.Module):
    """Conditional patch-level discriminator.

    Concatenates the label map and the photo along channels (3 + 3 = 6),
    downsamples three times (stride 2), then two 3x3 convolutions reduce it to
    a single-channel grid of sigmoid probabilities. For a 256x256 input the
    output grid is 30x30 (the last conv has no padding).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(6, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.conv = nn.Conv2d(256, 512, 3, 1, 1)
        self.bn = nn.BatchNorm2d(512)
        self.last = nn.Conv2d(512, 1, 3, 1)  # unpadded: trims one pixel per border

    def forward(self, anno, img):
        # F.dropout2d defaults to training=True, which would keep dropping
        # channels after dis.eval(); tie it to the module's mode instead.
        drop = self.training
        x = torch.cat([anno, img], dim=1)  # (N, 6, H, W) when both inputs are 3-channel
        x = self.down1(x, is_bn=False)
        x = self.down2(x)
        x = F.dropout2d(self.down3(x), training=drop)
        x = F.dropout2d(F.leaky_relu(self.conv(x)), training=drop)
        x = F.dropout2d(self.bn(x), training=drop)
        return torch.sigmoid(self.last(x))
五、完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image
# Training pairs: photos (*.jpg) and label maps (*.png) that share a file stem.
# glob.glob returns files in an arbitrary, filesystem-dependent order, so both
# lists are sorted to make imgs_path[i] and annos_path[i] the same sample.
imgs_path = sorted(glob.glob('dataset/base/*.jpg'))
annos_path = sorted(glob.glob('dataset/base/*.png'))
# ToTensor -> float CHW in [0, 1]; Resize -> 256x256; Normalize(0.5, 0.5) maps
# to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize((256, 256)),
                                transforms.Normalize(0.5, 0.5)])
class CMP_dataset(data.Dataset):
    """Paired (label map, photo) dataset for the CMP facade images.

    Args:
        imgs_path: paths of the photos (used as the generator's targets).
        annos_path: paths of the label maps (used as the generator's inputs);
            element i must belong to the same sample as ``imgs_path[i]``.

    Raises:
        ValueError: if the two path lists have different lengths.

    Each item is ``(anno, img)``, both passed through the module-level
    ``transform``.
    """

    def __init__(self, imgs_path, annos_path):
        if len(imgs_path) != len(annos_path):
            raise ValueError(
                f'got {len(imgs_path)} images but {len(annos_path)} annotations')
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        anno_path = self.annos_path[index]
        # Open inside `with` so file handles are released immediately;
        # convert('RGB') loads the pixels and guarantees 3 channels, which the
        # 6-channel discriminator input (3 + 3) depends on.
        with Image.open(img_path) as raw_img:
            pil_img = raw_img.convert('RGB')
        with Image.open(anno_path) as raw_anno:
            anno_img = raw_anno.convert('RGB')
        return transform(anno_img), transform(pil_img)

    def __len__(self):
        # Read this instance's list, not the module-level `imgs_path` of the
        # same name (otherwise every dataset reports the training-set length).
        return len(self.imgs_path)
# Shuffled mini-batches of 32 (annotation, photo) pairs for training.
dataset = CMP_dataset(imgs_path, annos_path)
dataloader = data.DataLoader(dataset=dataset,
                             batch_size=32,
                             shuffle=True)
# 创建模型
# 定义下采样模块
class Downsample(nn.Module):
    """Stride-2 encoder step: 3x3 conv -> LeakyReLU -> optional BatchNorm.

    Maps ``in_channels`` to ``out_channels`` and shrinks the spatial size
    (H, W) to (ceil(H/2), ceil(W/2)).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are what the state_dict keys are built from.
        conv = nn.Conv2d(in_channels, out_channels,
                         kernel_size=3, stride=2, padding=1)
        self.conv_relu = nn.Sequential(conv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        """Apply conv + LeakyReLU, then BatchNorm unless ``is_bn`` is False."""
        features = self.conv_relu(x)
        return self.bn(features) if is_bn else features
# 定义上采样模块
class Upsample(nn.Module):
    """Stride-2 decoder step: 3x3 transposed conv -> LeakyReLU -> BatchNorm -> optional dropout.

    Maps ``in_channels`` to ``out_channels`` and doubles the spatial size
    (H, W) to (2H, 2W).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        deconv = nn.ConvTranspose2d(in_channels, out_channels,
                                    kernel_size=3, stride=2,
                                    padding=1, output_padding=1)
        self.upconv_relu = nn.Sequential(deconv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        """Upsample and normalise; apply channel dropout when ``is_drop`` is True."""
        out = self.bn(self.upconv_relu(x))
        if not is_drop:
            return out
        # F.dropout2d defaults to training=True, so this dropout is active in
        # both train() and eval() mode.
        return F.dropout2d(out)
# 定义生成器:六个下采样层,五个上采样层,一个输出层
# UNet结构
class Generator(nn.Module):
    """U-Net generator: six stride-2 encoder steps, five decoder steps, one output layer.

    Each decoder output is concatenated channel-wise with the encoder map of
    the same resolution (skip connection) before the next decoder step; the
    first three decoder steps use dropout. Input height/width should be a
    multiple of 64 (2**6) so the skip connections line up. The output has 3
    channels, the input's resolution, and values in [-1, 1] (tanh).
    """

    def __init__(self):
        super().__init__()
        # Encoder: 3 -> 64 -> 128 -> 256 -> 512 -> 512 -> 512 channels.
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 512)
        self.down6 = Downsample(512, 512)
        # Decoder: from up2 on, in_channels are doubled by the skip concat.
        self.up1 = Upsample(512, 512)
        self.up2 = Upsample(1024, 512)
        self.up3 = Upsample(1024, 256)
        self.up4 = Upsample(512, 128)
        self.up5 = Upsample(256, 64)
        # Final 2x upsampling to full resolution; 128 = 64 (down1) + 64 (up5).
        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3, stride=2,
                                       padding=1, output_padding=1)

    def forward(self, x):
        # Encoder: keep every intermediate map except the bottleneck for skips.
        skips = [self.down1(x, is_bn=False)]  # no BatchNorm on the first layer
        for down in (self.down2, self.down3, self.down4, self.down5):
            skips.append(down(skips[-1]))
        h = self.down6(skips[-1])  # bottleneck at 1/64 resolution

        # Decoder: (module, use dropout), paired with skips deepest-first.
        decoder = ((self.up1, True), (self.up2, True), (self.up3, True),
                   (self.up4, False), (self.up5, False))
        for (up, drop), skip in zip(decoder, reversed(skips)):
            h = torch.cat([skip, up(h, is_drop=drop)], dim=1)
        return torch.tanh(self.last(h))
# 判别器
class Discriminator(nn.Module):
    """Conditional patch-level discriminator.

    Concatenates the label map and the photo along channels (3 + 3 = 6),
    downsamples three times (stride 2), then two 3x3 convolutions reduce it to
    a single-channel grid of sigmoid probabilities. For a 256x256 input the
    output grid is 30x30 (the last conv has no padding).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(6, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.conv = nn.Conv2d(256, 512, 3, 1, 1)
        self.bn = nn.BatchNorm2d(512)
        self.last = nn.Conv2d(512, 1, 3, 1)  # unpadded: trims one pixel per border

    def forward(self, anno, img):
        # F.dropout2d defaults to training=True, which would keep dropping
        # channels after dis.eval(); tie it to the module's mode instead.
        drop = self.training
        x = torch.cat([anno, img], dim=1)  # (N, 6, H, W) when both inputs are 3-channel
        x = self.down1(x, is_bn=False)
        x = self.down2(x)
        x = F.dropout2d(self.down3(x), training=drop)
        x = F.dropout2d(F.leaky_relu(self.conv(x)), training=drop)
        x = F.dropout2d(self.bn(x), training=drop)
        return torch.sigmoid(self.last(x))
# Run on the GPU when one is available; both networks live on that device.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)

# Adversarial (conditional-GAN) loss: binary cross-entropy on the
# discriminator's sigmoid outputs. Both networks use Adam, lr=2e-4,
# betas=(0.5, 0.999).
loss_fn = torch.nn.BCELoss()
adam_kwargs = dict(lr=2e-4, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(dis.parameters(), **adam_kwargs)
g_optimizer = torch.optim.Adam(gen.parameters(), **adam_kwargs)
# 绘图函数
def generate_images(model, test_input, true_target):
    """Plot input, ground truth and model prediction for the first sample of a batch.

    Tensors are expected as (N, C, H, W) in [-1, 1], e.g. the output of the
    dataset transform (assumption, not checked); they are mapped back to
    [0, 1] for display.

    Args:
        model: generator, called as ``model(test_input)``.
        test_input: batch of conditioning images fed to the model.
        true_target: batch of matching ground-truth images.
    """
    # Visualisation needs no autograd graph. The model's train/eval mode is
    # left as the caller set it, so BatchNorm/dropout behave unchanged.
    with torch.no_grad():
        prediction = model(test_input)
    # Take sample 0 and convert CHW -> HWC numpy arrays for imshow.
    display_list = [t[0].detach().permute(1, 2, 0).cpu().numpy()
                    for t in (test_input, true_target, prediction)]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']
    plt.figure(figsize=(15, 5))  # wide enough for three side-by-side panels
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(title[i])
        # [-1, 1] -> [0, 1]; clip guards against float round-off outside range.
        plt.imshow((display_list[i] * 0.5 + 0.5).clip(0, 1))
        plt.axis('off')
    plt.show()
# Held-out pairs from the `extended` folder. glob order is filesystem-dependent,
# so both lists are sorted to keep the i-th photo paired with the i-th label map.
test_imgs_path = sorted(glob.glob('dataset/extended/*.jpg'))
test_annos_path = sorted(glob.glob('dataset/extended/*.png'))
test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
test_dataloader = data.DataLoader(test_dataset, batch_size=32)

LAMBDA = 10  # weight of the L1 reconstruction term in the generator loss

# One fixed test batch, re-plotted after every epoch to track progress.
annos_batch, imgs_batch = next(iter(test_dataloader))
imgs_batch = imgs_batch.to(device)
annos_batch = annos_batch.to(device)

D_loss = []  # per-epoch mean discriminator loss
G_loss = []  # per-epoch mean generator loss
# Training: for every batch, one discriminator update then one generator update.
count = len(dataloader)  # batches per epoch, used to average the losses
for epoch in range(100):
    D_epoch_loss = 0
    G_epoch_loss = 0
    for annos, imgs in dataloader:
        imgs = imgs.to(device)
        annos = annos.to(device)

        # Discriminator step: real (label map, photo) pairs -> 1.
        d_optimizer.zero_grad()
        disc_real_output = dis(annos, imgs)
        d_real_loss = loss_fn(disc_real_output,
                              torch.ones_like(disc_real_output, device=device))
        d_real_loss.backward()
        # The generator is conditioned on the label map (it takes no random
        # noise input). detach() stops this step from reaching the generator.
        gen_output = gen(annos)
        disc_gen_output = dis(annos, gen_output.detach())
        d_fake_loss = loss_fn(disc_gen_output,
                              torch.zeros_like(disc_gen_output, device=device))
        d_fake_loss.backward()
        disc_loss = d_real_loss + d_fake_loss  # total discriminator loss
        d_optimizer.step()

        # Generator step: make D output 1 on generated pairs, plus LAMBDA * L1
        # distance to the real photo.
        g_optimizer.zero_grad()
        disc_gen_output = dis(annos, gen_output)
        gen_loss_crossentropy = loss_fn(disc_gen_output,
                                        torch.ones_like(disc_gen_output, device=device))
        gen_l1_loss = torch.mean(torch.abs(imgs - gen_output))
        gen_loss = gen_loss_crossentropy + (LAMBDA * gen_l1_loss)
        gen_loss.backward()
        g_optimizer.step()

        # .item() returns a plain Python float, so no autograd graph is kept.
        D_epoch_loss += disc_loss.item()
        G_epoch_loss += gen_loss.item()

    D_epoch_loss /= count
    G_epoch_loss /= count
    D_loss.append(D_epoch_loss)
    G_loss.append(G_epoch_loss)
    # End of epoch: report and plot predictions for the fixed test batch.
    print("Epoch:", epoch)
    generate_images(gen, annos_batch, imgs_batch)
六、结果展示
分别为 Epoch 0、Epoch 50 和最后一轮(训练循环为 range(100),即 Epoch 99)的生成结果
《新程序员》:云原生和全面数字化实践 50位技术专家共同创作,文字、视频、音频交互阅读
pytorch搭建基本的gan模型及训练过程(代码片段)
...数绘图函数GAN的训练输出整体代码参考资料概述本文通过Pytorch搭建基本的GAN模型结构,并通过torchvision的MNIST数据集进行测试。对于GAN模型的基本结构及公式的理解可以看前一篇博客:GAN的理论知识及公式的理解下文的实... 查看详情
深度学习理论与实战pytorch实现
课程目录:01.预备内容(入门)02.Python基础(入门)03.PyTorch基础(入门)04.神经网络(进阶)05.卷积神经网络(进阶)06.循环神经网络(进阶)07.生成对抗网络GAN(进阶)08.强化学习(进阶)09.毕业项目 下载地址:深度学习理... 查看详情
手撕cnn经典网络之vggnet(pytorch实战篇)
...)详细介绍了VGGNet的网络结构,今天我们将使用PyTorch来复现VGGNet网络,并用VGGNet模型来解决一个经典的Kaggle图像识别比赛问题。正文开始!1.数据集制作在论文中AlexNet作者使用的是 查看详情
手撕cnn经典网络之vggnet(pytorch实战篇)
...)详细介绍了VGGNet的网络结构,今天我们将使用PyTorch来复现VGGNet网络,并用VGGNet模型来解决一个经典的Kaggle图像识别比赛问题。正文开始!1.数据集制作在论文中AlexNet作者使用的是 查看详情
Pytorch GAN 模型未训练:矩阵乘法错误
【问题标题】:Pytorch GAN 模型未训练:矩阵乘法错误【英文标题】:Pytorch GAN model doesn't train: matrix multiplication error【发布时间】:2021-07-12 17:47:22【问题描述】:我正在尝试构建一个基本的GAN来熟悉Pytorch。我对Keras有一些(有限的)经验,但由于我... 查看详情
手撕cnn之alexnet(pytorch实战篇)
...)详细介绍了AlexNet的网络结构,今天我们将使用 PyTorch 来复现AlexNet网络,并用AlexNet模型来解决一个经典的Kaggle图像识别比赛问题。正文开始!1. 数据集制作在论文中AlexNe 查看详情
手撕cnn之alexnet(pytorch实战篇)
...)详细介绍了AlexNet的网络结构,今天我们将使用 PyTorch 来复现AlexNet网络,并用AlexNet模型来解决一个经典的Kaggle图像识别比赛问题。正文开始!1. 数据集制作在论文中AlexNe 查看详情
手把手写深度学习(11):用pix2pixgans实现sketch-to-image跨模态任务(实战演练)
前言:上一篇讲了很多Pix2PixGANs的基础,没想到越写越多,最后写了将近4w字,所以拆分成两篇:理论基础和实战演练分别介绍。这篇讲怎样用pix2pixGANs完成sketch-to-image任务,重点放在代码演练上。目录计算对抗损失和L1损失总体... 查看详情
学习打卡07可解释机器学习笔记之shape+lime代码实战(代码片段)
...记之Shape+Lime代码实战基于Shapley值的可解释性分析使用Pytorch对MNIST分类可解释性分析使用shap的DeepExplainer进行可视化使用Pytorch对预训练ImageNet图像分类可解释性分析指定单个预测类别指定多个预测类别前k个预测类别LIME代码实... 查看详情
学习打卡07可解释机器学习笔记之shape+lime代码实战(代码片段)
...记之Shape+Lime代码实战基于Shapley值的可解释性分析使用Pytorch对MNIST分类可解释性分析使用shap的DeepExplainer进行可视化使用Pytorch对预训练ImageNet图像分类可解释性分析指定单个预测类别指定多个预测类别前k个预测类别LIME代码实... 查看详情
翻车现场:我用pytorch和gan做了一个生成神奇宝贝的失败模型(代码片段)
前言神奇宝贝已经是一个家喻户晓的动画了,我们今天来确认是否可以使用深度学习为他自动创建新的Pokemon。我最终成功地使用了生成对抗网络(GAN)生成了类似Pokemon的图像,但是这个图像看起来并不像神奇宝贝。虽然这个尝... 查看详情
pytorch生成对抗网络gan的基础教学简单实例(附代码数据集)
1.简介这篇文章主要是介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码。数据集使用的是开源人脸图像数据集img_align_celeba,共1.34G。生成器与判别器模型均采用简单的卷积结构,代码参考... 查看详情
gan-生成对抗神经网络(pytorch)-合集gan-dcgan-cgan(代码片段)
原生GAN(GenerativeAdversarialNets)训练过程也是老三步了,再啰嗦一遍:使用真实图片训练辨别器,标签为真使用生成器生成的图片训练判别器,标签为假,此时图片使用生成器计算得来的,喂给判别... 查看详情
gan-生成对抗神经网络(pytorch)-合集gan-dcgan-cgan(代码片段)
原生GAN(GenerativeAdversarialNets)训练过程也是老三步了,再啰嗦一遍:使用真实图片训练辨别器,标签为真使用生成器生成的图片训练判别器,标签为假,此时图片使用生成器计算得来的,喂给判别... 查看详情
美团云tensorflow生成对抗网络(generativeadversarialnetworks)实战案例(代码片段)
生成对抗网络GAN概述GAN的前世今生GAN的基本原理GAN的代码实现在DLS运行的注意事项数据文件的读取数据文件的写入GAN示例代码概述本文主要介绍GAN的基本知识,以及在DLS上运行的注意事项。本模块继续通过经典的MNIST数据集... 查看详情
对比学习:《深度学习之pytorch》《pytorch深度学习实战》+代码
PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉、自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的AllenNLP,用于概率图模型的Pyro,扩展了PyTorch的功能。通... 查看详情
将 torch.backward() 用于 GAN 生成器时,为啥 Pytorch 中的判别器损失没有变化?
【问题标题】:将 torch.backward() 用于 GAN 生成器时,为啥 Pytorch 中的判别器损失没有变化?【英文标题】:When using torch.backward() for a GAN's generator, why doesn't discriminator losses change in Pytorch?将torch.backward()用于GAN生成器时,为什么Pytorch中的判别器损失没有变化?... 查看详情
有没有pytorch的高级应用教程推荐一下
深度学习与PyTorch入门实战教程百度网盘免费资源在线学习 链接:https://pan.baidu.com/s/1BQtdabkn-aj0SXVv3x62ww 提取码:r93i 深度学习与PyTorch入门实战教程9.卷积神经网络CNN8.过拟合7.神经网络与全连接层6.随机梯度下... 查看详情