pytorch生成对抗网络gan的基础教学简单实例(附代码数据集)

Lizhi_Tech Lizhi_Tech     2023-03-28     437

关键词:

1.简介

这篇文章主要是介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码。数据集使用的是开源人脸图像数据集img_align_celeba,共1.34G。生成器与判别器模型均采用简单的卷积结构,代码参考了pytorch官网。

建议对pytorch和神经网络原理还不熟悉的同学,可以先看下之前的文章了解下基础:pytorch基础教学简单实例(附代码)_Lizhi_Tech的博客-CSDN博客_pytorch实例


2.GAN原理

简而言之,生成对抗网络可以归纳为以下几个步骤:

  1. 随机噪声输入进生成器,生成虚假图片。

  1. 将带标签的虚假图片和真实图片输入进判别器进行更新,最大化 log(D(x)) + log(1 - D(G(z)))。

  1. 根据判别器的输出结果更新生成器,最大化 log(D(G(z)))。


3.代码

from __future__ import print_function
import random
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt

# 设置随机算子
manualSeed = 999
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# 数据集位置
dataroot = "data/celeba"

# dataloader的核数
workers = 2

# Batch大小
batch_size = 128

# 图像缩放大小
image_size = 64

# 图像通道数
nc = 3

# 隐向量维度
nz = 100

# 生成器特征维度
ngf = 64

# 判别器特征维度
ndf = 64

# 训练轮数
num_epochs = 5

# 学习率
lr = 0.0002

# Adam优化器的beta系数
beta1 = 0.5

# gpu个数
ngpu = 1

# 加载数据集
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# 创建dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# 使用cpu还是gpu
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# 初始化权重
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# 生成器
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

# 实例化生成器并初始化权重
netG = Generator(ngpu).to(device)
netG.apply(weights_init)

# 判别器
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

# 实例化判别器并初始化权重
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)


# 损失函数
criterion = nn.BCELoss()

# 随机输入噪声
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# 真实标签与虚假标签
real_label = 1.
fake_label = 0.

# 创建优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# 开始训练
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) 更新D: 最大化 log(D(x)) + log(1 - D(G(z)))
        ###########################
        # 使用真实标签的batch训练
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # 使用虚假标签的batch训练
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        # 更新D
        optimizerD.step()

        ############################
        # (2) 更新G: 最大化 log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        # 更新G
        optimizerG.step()

        # 输出训练状态
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\\tLoss_D: %.4f\\tLoss_G: %.4f\\tD(x): %.4f\\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # 保存每轮loss
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # 记录生成的结果
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

# loss曲线
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()


# 生成效果图
real_batch = next(iter(dataloader))

# 真实图像
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# 生成的虚假图像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

4.结果

真实图像与生成图像

loss曲线


完整的代码与数据集可以在我的github上找到:lizhiTech/pytorch_GAN_simple_example: 介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码 (github.com)

或者csdn下载:

我们也提供包括深度学习、计算机视觉、机器学习等其他方向的其他代码及辅导服务,有需求可以通过csdn私聊或github上的联系方式联系我们。

gan-生成对抗网络-生成手写数字(基于pytorch)(代码片段)

什么是GANGAN(GenerativeAdversarialNetwork),网络也如他的名字一样,有生成,有对抗,两个网络相互博弈。我们给两个网络起个名字,第一个网络用来生成数据命名为生成器(generator),另一个网络用来... 查看详情

利用tensorflow训练简单的生成对抗网络gan

...ets中提出来的。原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(discriminator)之间博弈的过程。整个网络训练的过程中,两个模块的分工判断器,直观来看就是一个简单的神经网络结构,输入就是一副图像,... 查看详情

gan拟合高斯分布数据pytorch实现

参考技术AGAN本身是一种生成式模型,所以在数据生成上用的是最普遍的,最常见的是图片生成,常用的有DCGANWGAN,BEGAN。目前比较有意思的应用就是GAN用在图像风格迁移,图像降噪修复,图像超分辨率了,都有比较好的结果。目... 查看详情

gan-生成对抗网络(pytorch)合集--pixtopix-cyclegan(代码片段)

pixtopix(像素到像素)原文连接:https://arxiv.org/pdf/1611.07004.pdf输入一个域的图片转换为另一个域的图片(白天照片转成黑夜)如下图,输入标记图片,输出真实图片缺点就是训练集两个域的图片要一一对应,... 查看详情

翻车现场:我用pytorch和gan做了一个生成神奇宝贝的失败模型(代码片段)

前言神奇宝贝已经是一个家喻户晓的动画了,我们今天来确认是否可以使用深度学习为他自动创建新的Pokemon。我最终成功地使用了生成对抗网络(GAN)生成了类似Pokemon的图像,但是这个图像看起来并不像神奇宝贝。虽然这个尝... 查看详情

[pytorch系列-65]:生成对抗网络gan-图像生成开源项目pytorch-cyclegan-and-pix2pix-无监督图像生成cyclegan的基本原理

...cleGan网络的图像转换1.1项目在github上的源代码GitHub-junyanz/pytorch-CycleGAN-and-pix2pix:Image-to-ImageTranslationinPyTorch1.2 图像转换实例(1)环境的轮廓不变,更换颜色,风格(秋天-夏天)(2)马的轮廓不变,... 查看详情

pytorch深度学习50篇·······第七篇:gan生成对抗网络---pix2pix(代码片段)

可能看过看过我上两篇GAN和CGAN的朋友们都认为,mnist数据太简单了,也不太适合拿出去show,所以我们来一个复杂一点的,这次难度比之前两篇的难度又有所提升了,所以,请大家不要慌张,紧跟脚步&#x... 查看详情

[pytorch系列-74]:生成对抗网络gan-图像生成开源项目pytorch-cyclegan-and-pix2pix-pix2pix网络结构与代码实现详解(代码片段)

...;还需要与样本标签图片进行像素级的比较。1.2 代码来源pytorch-CycleGAN-and-pix2pix\\models\\pix2pix_model.py1.3网络结构代码解读def__init__(self,opt):"""Initializethepix2pixclass.Pa 查看详情

pytorchnote45生成对抗网络(gan)(代码片段)

PytorchNote45生成对抗网络(GAN)文章目录PytorchNote45生成对抗网络(GAN)GANsDiscriminatorNetworkGeneratorNetwork简单版本的生成对抗网络判别网络生成网络全部笔记的汇总贴:PytorchNote快乐星球2014年,深度学习三巨头... 查看详情

《生成对抗网络gan的原理与应用专题》笔记

...是GAN框架简述GAN全称是GenerativeAdversarialNets,中文叫做“生成对抗网络”。在GAN中有2个网络,一个网络用于生成数据,叫做“生成器”。另一个网络用于判别生成数据是否接近于真实,叫做“判别器”。下图展示了最简单的GAN的... 查看详情

深度卷积生成对抗网络

深度卷积生成对抗网络DeepConvolutionalGenerativeAdversarialNetworksGANs如何工作的基本思想。可以从一些简单的,易于抽样的分布,如均匀分布或正态分布中提取样本,并将其转换成与某些数据集的分布相匹配的样本。虽然例子匹配一个... 查看详情

[pytorch系列-75]:生成对抗网络gan-图像生成开源项目pytorch-cyclegan-and-pix2pix-cyclegan网络结构与代码实现详解(代码片段)

...:G_A2B,D_A2B,G_B2A,D_B2A,后两个是新增的。1.2 代码来源pytorch-CycleGAN-and-pix2pix\\models\\cycle_gan_model.py1.3网络结构代码解读def__init__(self, 查看详情

不要怂,就是gan(生成式对抗网络)

前面我们了解了GAN的原理,下面我们就来用TensorFlow搭建GAN(严格说来是DCGAN,如无特别说明,本系列文章所说的GAN均指DCGAN),如前面所说,GAN分为有约束条件的GAN,和不加约束条件的GAN,我们先来搭建一个简单的MNIST数据集上... 查看详情

深度学习gan生成对抗网络-1010格式数据生成简单案例(代码片段)

...式,而是使用一个非常简单的案例来帮助我们了解GAN生成对抗网络。二、GAN概念生成对抗网络(GenerativeAdversarialNetworks,GAN)包含生成器(Generator)和鉴别器(Discriminator)两个神经网络。生成器用... 查看详情

gan-生成对抗神经网络(pytorch)-合集gan-dcgan-cgan(代码片段)

原生GAN(GenerativeAdversarialNets)训练过程也是老三步了,再啰嗦一遍:使用真实图片训练辨别器,标签为真使用生成器生成的图片训练判别器,标签为假,此时图片使用生成器计算得来的,喂给判别... 查看详情

gan-生成对抗神经网络(pytorch)-合集gan-dcgan-cgan(代码片段)

原生GAN(GenerativeAdversarialNets)训练过程也是老三步了,再啰嗦一遍:使用真实图片训练辨别器,标签为真使用生成器生成的图片训练判别器,标签为假,此时图片使用生成器计算得来的,喂给判别... 查看详情

pytorch入门实战:基于gan生成简单的动漫人物头像(代码片段)

...础知识,可参考我的学习笔记或观看李宏毅老师课程Pytorch中DataLoader和Dataset的基本用法反卷积通俗详细解析与nn.ConvTranspose2d重要参数解释TensorBoard快速入门(Pytorch使用TensorBoard)本文内容本文参考李彦宏老师2021年度的... 查看详情

[pytorch系列-63]:生成对抗网络gan-图像生成开源项目pytorch-cyclegan-and-pix2pix-代码总体架构与总体学习思路

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121940011目录第1章理论概述1.1 普通GAN,pix2pix,CycleGAN和pix2pixHD的演变过程... 查看详情