pytorch入门实战:基于gan生成简单的动漫人物头像(代码片段)

iioSnail iioSnail     2022-11-29     760

关键词:

本文涉及知识点

  1. GAN网络基础知识,可参考我的学习笔记或观看李宏毅老师课程
  2. Pytorch中DataLoader和Dataset的基本用法
  3. 反卷积通俗详细解析与nn.ConvTranspose2d重要参数解释
  4. TensorBoard快速入门(Pytorch使用TensorBoard)

本文内容

本文参考李彦宏老师2021年度的GAN作业06,训练一个生成动漫人物头像的GAN网络。本篇是入门篇,所以使用最简单的GAN网络,所以生成的动漫人物头像也较为模糊。最终效果为(我这边只训练了40个epoch):

全局参数

首先导入需要用到的包:

import os
import sys

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm 
from torch.utils.tensorboard import SummaryWriter

设置一些全局参数:

batch_size = 64
num_workers = 2
n_epoch = 100
z_dim = 100 # 噪音向量的维度
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 模型的保存路径,在Google Colab中,若挂在了Google Drive,则模型会被保存到Google Drive上
ckpt_dir = 'drive/MyDrive/models' 
faces_path = "faces"  # 数据集所在的目录

print("Device: ", device) # 打印一下设备,防止训练半天发现是在CPU上做的
Device:  cuda

数据集

数据集是若干动漫人物头像,下载链接地址如下:

https://pan.baidu.com/s/1zsJJJapFLr1zWWhgGol-aA 提取码:2k4z

下载好之后,将其解压到当前目录下,最终为:

faces/
├── 1.jpg
├── 2.jpg
├── 3.jpg
...

工具函数

这里定义一个用于清理输出的工具类,用于训练过程中清理输出,要不然太多了。

def clear_output():
    """
    清理Jupyter Notebook中的输出
    """
    os.system('cls' if os.name == 'nt' else 'clear')
    if 'ipykernel' in sys.modules:
        from IPython.display import clear_output as clear
        clear()

数据预处理

定义Dataset, 我这里将头像图片缩放到64x64的大小,并对其进行了标准化:

class CrypkoDataset(Dataset):
    def __init__(self, img_path='./faces'):
        self.fnames = [img_path + '/' + img for img in os.listdir(img_path)]

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            # 这里将图片缩放到了均值为0.5,方差为0.5的区间,本质是执行了 (x-0.5)/0.5
            # 详情可参考:https://blog.csdn.net/zhaohongfei_358/article/details/117910661
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

        self.num_samples = len(self.fnames)


    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = torchvision.io.read_image(fname)
        img = self.transform(img)
        return img


    def __len__(self):
        return self.num_samples
dataset = CrypkoDataset(faces_path)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

这里简单测试下各个方法:

dataset.__getitem__(0).size(), len(dataset)
(torch.Size([3, 64, 64]), 71314)

可以看到,我们将图片成功缩放到64x64大小,一共有71314张图片。

接下来展示几个图片,看看效果:

images = [(dataset[i] + 1) / 2 for i in range(16)]  # 拿出16张图片
grid_img = torchvision.utils.make_grid(images, nrow=4)  # 将其组合成一个4x4的网格
plt.figure(figsize=(6, 6))
plt.imshow(grid_img.permute(1, 2, 0)) # plt接收的图片通道要在最后,所以permute一下
plt.show()

关于 (dataset[i] + 1) / 2 的简单解释:由于上面执行了标准化,即 y=(x-0.5)/0.5,所以这里要将其变回x,所以执行了 x=0.5y+0.5=(y+1)/2


定义模型

数据集准备完毕后,开始定义我们的模型,GAN需要定义Generator和Discriminator,Generator用于生成图片,Discriminator用来鉴别图片是生成的还是真实的,本文使用DCGAN(Deep Convolutional GAN)。

Generator

class Generator(nn.Module):
    """
    输入Shape为(N, in_dim),N为batch_size, in_dim是随机向量的维度
    输出Shape为(N, 3, 64, 64),即生成N张64x64的彩色图像
    """

    def __init__(self, in_dim, dim=64):
        super(Generator, self).__init__()

        def dconv_bn_relu(in_dim, out_dim):
            return nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                                   padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU()
            )

        # 1. 先用线性层将随机向量变成 dim*8 个通道,大小为4x4的图片
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 8 * 4 * 4),
            nn.ReLU()
        )

        # 2. 然后就一直反卷积,不断的将图片变大,同时通道不断减小,最终变成一个3通道,64x64大小的图片
        self.l2_5 = nn.Sequential(
            dconv_bn_relu(dim * 8, dim * 4),
            dconv_bn_relu(dim * 4, dim * 2),
            dconv_bn_relu(dim * 2, dim),
            nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        y = self.l1(x)
        y = y.view(y.size(0), -1, 4, 4)
        y = self.l2_5(y)
        return y

Discriminator

class Discriminator(nn.Module):
    """
    输入Shape为(N, 3, 64, 64),即N张64x64的彩色图片
    输出Shape为(N,), 即这N个图片每张图片的真实率,越接近1表示Discriminator越觉得它是真的
    """

    def __init__(self, in_dim=3, dim=64): # 注意这里的in_dim是指的图片的通道数,所以是3
        super(Discriminator, self).__init__()

        def conv_bn_lrelu(in_dim, out_dim):
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, 5, 2, 2),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2),
            )

        # 就是一堆卷积一直卷,把原始的图片最终卷成一个数字
        self.ls = nn.Sequential(
            nn.Conv2d(in_dim, dim, 5, 2, 2),
            nn.LeakyReLU(0.2),
            conv_bn_lrelu(dim, dim * 2),
            conv_bn_lrelu(dim * 2, dim * 4),
            conv_bn_lrelu(dim * 4, dim * 8),
            nn.Conv2d(dim * 8, 1, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        y = self.ls(x)
        y = y.view(-1)
        return y
G = Generator(in_dim=z_dim)
D = Discriminator()
G = G.to(device)
D = D.to(device)

由于Discriminator做的是二分类问题,所以这里使用Binary Cross Entropy

criterion = nn.BCELoss()
opt_D = torch.optim.Adam(D.parameters(), lr=learning_rate)
opt_G = torch.optim.Adam(G.parameters(), lr=learning_rate)

训练模型

这里使用tensorboard记录loss变化和生成的图片:

writer = SummaryWriter()

然后就可以启动tensorboard了:

tensorboard --logdir runs

开始训练:

steps = 0
log_after_step = 50 # 多少步记录一次Loss

# 用于评估阶段的z向量
z_sample = Variable(torch.randn(100, z_dim)).to(device)

for e, epoch in enumerate(range(n_epoch)):
    total_loss_D = 0
    total_loss_G = 0

    for i, data in enumerate(tqdm(dataloader, desc='Epoch : '.format(e))):
        imgs = data
        imgs = imgs.to(device)

        # 重新获取batch_size,因为到最后一组的时候,可能凑不够
        batch_size = imgs.size(0)

        # ============================================
        #  训练Discriminator
        # ============================================
        # 1. 得到一批随机的噪音向量 z
        z = Variable(torch.randn(batch_size, z_dim)).to(device)
        # 2. 得到真实(real)的图片
        r_imgs = Variable(imgs).to(device)
        # 3. 使用 Generator生成一批假(fake)图片
        f_imgs = G(z)

        # 构建标签,真实图片的标签都为1,假图片的标签都为0
        r_label = torch.ones((batch_size, )).to(device)
        f_label = torch.zeros((batch_size, )).to(device)

        # 用Discriminator对真实图片和假图片进行判别
        r_logit = D(r_imgs.detach())
        f_logit = D(f_imgs.detach())

        # 计算Discriminator的损失
        r_loss = criterion(r_logit, r_label)
        f_loss = criterion(f_logit, f_label)
        loss_D = (r_loss + f_loss) / 2
        total_loss_D += loss_D

        # 对Discriminator进行反向传播
        D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # ============================================
        # 训练Generator
        # ============================================
        # 1. 生成N张假图片
        z = Variable(torch.randn(batch_size, z_dim)).to(device)
        f_imgs = G(z)

        # 2. 让Discriminator判别这些假图片
        f_logit = D(f_imgs)

        # 3. 计算损失,这里Generator是希望图片越真越好,所以参数是f_logit和r_label
        loss_G = criterion(f_logit, r_label)
        total_loss_G += loss_G

        # 对Generator进行反向传播
        G.zero_grad()
        loss_G.backward()
        opt_G.step()

        steps += 1

        if steps % log_after_step == 0:
            writer.add_scalars("loss", 
                "Loss_D": total_loss_D / log_after_step,
                "Loss_G": total_loss_G / log_after_step
            , global_step=steps)

    # 清理之前的输出
    clear_output()

    # 每一个epoch后,生成一张一组图片看看效果
    G.eval()

    # 用Generator生成图片,并进行去除标准化,然后保存到logs目录下
    f_imgs_sample = (G(z_sample).data + 1) / 2.0
    if not os.path.exists('logs'):
        os.makedirs('logs')
    filename = os.path.join('logs', f'Epoch_epoch + 1:03d.jpg')
    # 将生成的图片保存下来
    torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
    print(f' | Save some samples to filename.')

    # 展示一下生成的图片
    grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()

    # Tensorboard记录一下生成的图片
    writer.add_image("Generated_Images", grid_img, global_step=steps)

    # 将Generator重新调整为训练模式
    G.train()

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    # 每5个epoch保存一次模型
    if (e + 1) % 5 == 0 or e == 0:
        # Save the checkpoints.
        torch.save(G.state_dict(), os.path.join(ckpt_dir, 'G_.pth'.format(steps)))
        torch.save(D.state_dict(), os.path.join(ckpt_dir, 'D_.pth'.format(steps)))

这是第40个epoch后的输出,我在之后就停止了。

我们来观察下Tensorboard面板:


红色为Generator的损失,蓝色为Discriminator的损失,从这张图我们可以看到两个现象:

  1. Discriminator的损失明显小于Generator:这是GAN常见的一个问题,原因是Discriminator的任务比Generator简单的多,毕竟学会分辨真假画可比学会画画简单多了。这样就会导致Discriminator能给Generator提供的信息太少,进而导致Generator无法收敛
  2. Generator的损失不断的震荡(其实Discriminator也在震荡):其实震荡才是正常的,因为Generator的损失是“是否能骗过Discriminator”,第一次他没骗过(损失高),然后G进化了,就能骗过了(损失低),然后Discrimnator也进化了,然后G又骗不过了(损失高),然后就这样依次反复。 Discriminator同理。反过来,如果Generator的损失不断下降,那么就要考虑是不是Generator太容易骗过Discriminator了,然后想着优化Discriminator。

我在训练过程也记录了图片的变化:


可以看到,Generator生成的图片质量越来越好

模型使用

训练完成后,我们来尝试使用下模型:

G.eval()
inputs = torch.rand(1, 100).to(device)
outputs = G(inputs)
outputs = (outputs.data + 1) / 2.0
plt.figure(figsize=(5, 5))
plt.imshow(outputs[0].cpu().permute(1, 2, 0))
plt.show()

。。。 有点丑,但至少还能看出来是个妹子

深度学习理论与实战pytorch实现

课程目录:01.预备内容(入门)02.Python基础(入门)03.PyTorch基础(入门)04.神经网络(进阶)05.卷积神经网络(进阶)06.循环神经网络(进阶)07.生成对抗网络GAN(进阶)08.强化学习(进阶)09.毕业项目 下载地址:深度学习理... 查看详情

pytorch搭建基本的gan模型及训练过程(代码片段)

...数绘图函数GAN的训练输出整体代码参考资料概述本文通过Pytorch搭建基本的GAN模型结构,并通过torchvision的MNIST数据集进行测试。对于GAN模型的基本结构及公式的理解可以看前一篇博客:GAN的理论知识及公式的理解下文的实... 查看详情

猿创征文|animeganv2照片动漫化:如何基于pytorch和神经网络给girlfriend制作漫画风头像?(代码片段)

文章目录前言一、基于GAN实现漫画风格实现原理1.1、传统漫画风格迁移工具的不足1.2、基于生成对抗网络(GAN)的漫画风格迁移工具二、AnimeGANv2照片动漫化2.1、与AnimeGAN的对比2.2、AnimeGANv2效果及项目介绍三、本次案例部署及实验平... 查看详情

猿创征文|animeganv2照片动漫化:如何基于pytorch和神经网络给girlfriend制作漫画风头像?(代码片段)

文章目录前言一、基于GAN实现漫画风格实现原理1.1、传统漫画风格迁移工具的不足1.2、基于生成对抗网络(GAN)的漫画风格迁移工具二、AnimeGANv2照片动漫化2.1、与AnimeGAN的对比2.2、AnimeGANv2效果及项目介绍三、本次案例部署及实验平... 查看详情

gan-生成对抗网络-生成手写数字(基于pytorch)(代码片段)

什么是GANGAN(GenerativeAdversarialNetwork),网络也如他的名字一样,有生成,有对抗,两个网络相互博弈。我们给两个网络起个名字,第一个网络用来生成数据命名为生成器(generator),另一个网络用来... 查看详情

pytorch实现简单的生成对抗网络gan

...sp;   以下程序主要来自廖星宇的《深度学习之PyTorch》的第六章,本文对原代码进行了改进:importtorchfromtorchimportnnimporttorchvision.transformsastfsfromtorch.utils.dataimportDataLoaderfromtorchvision.datasetsimportMNISTimportnumpyasnpimportmatplotlib.... 查看详情

pytorch生成对抗网络gan的基础教学简单实例(附代码数据集)

1.简介这篇文章主要是介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码。数据集使用的是开源人脸图像数据集img_align_celeba,共1.34G。生成器与判别器模型均采用简单的卷积结构,代码参考... 查看详情

pytorch入门与实战----pytorch入门(代码片段)

1.深度学习框架 pytorch与其他框架的比较pytorch的学习方法:课程安排:PyTorch是一个基于Python的科学计算库,它有以下特点:类似于NumPy,但是它可以使用GPU可以用它定义深度学习模型,可以灵活地进行深度学习模型的训练和使... 查看详情

生成式模型入门:gan与vae——“以假乱真”的哲学

参考技术A在图像处理中,比较常见的任务有识别、检测、追踪等,这些任务的模型通常在训练阶段通过参数估计学得如何提取输入图像的特征,并建立输入图像与输出之间的映射,在应用阶段之间提取输入图像的特征,以得到... 查看详情

实战教程|用让新海诚本人惊讶的ai模型制作属于你的动漫视频

大家好,我是Charmve!本文将介绍如何使用GAN模型来生成属于你自己的动漫风格的视频,为自己、喜欢的菇凉或者调皮可爱孩子生成一个别具一格的动漫风格的视频。本文操作难度较低,适合想要试玩了解GAN模型的同学。可以同... 查看详情

实战教程|用让新海诚本人惊讶的ai模型制作属于你的动漫视频

大家好,我是Charmve!本文将介绍如何使用GAN模型来生成属于你自己的动漫风格的视频,为自己、喜欢的菇凉或者调皮可爱孩子生成一个别具一格的动漫风格的视频。本文操作难度较低,适合想要试玩了解GAN模型的同学。可以同... 查看详情

翻车现场:我用pytorch和gan做了一个生成神奇宝贝的失败模型(代码片段)

前言神奇宝贝已经是一个家喻户晓的动画了,我们今天来确认是否可以使用深度学习为他自动创建新的Pokemon。我最终成功地使用了生成对抗网络(GAN)生成了类似Pokemon的图像,但是这个图像看起来并不像神奇宝贝。虽然这个尝... 查看详情

gan拟合高斯分布数据pytorch实现

参考技术AGAN本身是一种生成式模型,所以在数据生成上用的是最普遍的,最常见的是图片生成,常用的有DCGANWGAN,BEGAN。目前比较有意思的应用就是GAN用在图像风格迁移,图像降噪修复,图像超分辨率了,都有比较好的结果。目... 查看详情

对比学习:《深度学习之pytorch》《pytorch深度学习实战》+代码

PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉、自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的AllenNLP,用于概率图模型的Pyro,扩展了PyTorch的功能。通... 查看详情

将 torch.backward() 用于 GAN 生成器时,为啥 Pytorch 中的判别器损失没有变化?

】将torch.backward()用于GAN生成器时,为啥Pytorch中的判别器损失没有变化?【英文标题】:Whenusingtorch.backward()foraGANsgenerator,whydoesn\'tdiscriminatorlosseschangeinPytorch?将torch.backward()用于GAN生成器时,为什么Pytorch中的判别器损失没有变化?... 查看详情

gan实战之pytorch使用pix2pixgan生成建筑物图片(代码片段)

pix2pixGAN的相关原理说明参考:GAN笔记--pix2pixGAN网络原理介绍以及论文解读一、数据集加载    数据集下载:pix2pixGAN训练数据集,建筑物数据集解压分成后有两个文件夹base和extend,我们将base作为训练集,exten... 查看详情

deepfake入门了解

参考技术ADeepfake,是由“deepmachinelearning”(深度机器学习)和“fakephoto”(假照片)组合而成,本质是一种深度学习模型在图像合成、替换领域的技术框架,属于深度图像生成模型的一次成功应用。其实该技术最早版本在2018年... 查看详情

有没有pytorch的高级应用教程推荐一下

深度学习与PyTorch入门实战教程百度网盘免费资源在线学习   链接:https://pan.baidu.com/s/1BQtdabkn-aj0SXVv3x62ww 提取码:r93i  深度学习与PyTorch入门实战教程9.卷积神经网络CNN8.过拟合7.神经网络与全连接层6.随机梯度下... 查看详情