pytorch搭建基本的gan模型及训练过程(代码片段)

半岛铁子_ 半岛铁子_     2022-12-08     119

关键词:

文章目录

概述

本文通过Pytorch搭建基本的GAN模型结构,并通过 torchvision 的 MNIST 数据集进行测试。
对于GAN模型的基本结构及公式的理解可以看前一篇博客:
GAN的理论知识及公式的理解
下文的实现完全对照这一篇博客的基本理论。

代码实战

代码是基于Pytorch环境创建,需要先安装Pytorch环境
Pytorch环境搭建教程链接:
Pytorch搭建教程

导包

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

数据准备

# 对数据做归一化 (-1, 1)
transform = transforms.Compose([
    transforms.ToTensor(),         # 将数据转换成Tensor格式,channel, high, witch,数据在(0, 1)范围内
    transforms.Normalize(0.5, 0.5) # 通过均值和方差将数据归一化到(-1, 1)之间
])

# 下载数据集
train_ds = torchvision.datasets.MNIST('data',
                                      train=True,
                                      transform=transform,
                                      download=True)
                                      
# 设置dataloader
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

# 返回一个批次的数据
imgs, _ = next(iter(dataloader))

# imgs的大小
imgs.shape

定义生成器

# 输入是长度为 100 的 噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
# linear 1 :   100----256
# linear 2:    256----512
# linear 2:    512----28*28
# reshape:     28*28----(1, 28, 28)

class Generator(nn.Module): #创建的 Generator 类继承自 nn.Module
    def __init__(self): # 定义初始化方法
        super(Generator, self).__init__() #继承父类的属性
        self.main = nn.Sequential( #使用Sequential快速创建模型
                                  nn.Linear(100, 256),
                                  nn.ReLU(),
                                  nn.Linear(256, 512),
                                  nn.ReLU(),
                                  nn.Linear(512, 28*28),
                                  nn.Tanh()                     # 输出层使用Tanh()激活函数,使输出-1, 1之间
        )
    def forward(self, x):              # 定义前向传播 x 表示长度为100 的noise输入
        img = self.main(x)
        img = img.view(-1, 28, 28) #将img展平,转化成图片的形式,channel为1可写可不写
        return img

定义判别器

## 输入为(1, 28, 28)的图片  输出为二分类的概率值,输出使用sigmoid激活 0-1
# BCEloss计算交叉熵损失

# nn.LeakyReLU   f(x) : x>0 输出 x, 如果x<0 ,输出 a*x  a表示一个很小的斜率,比如0.1
# 判别器中一般推荐使用 LeakyReLU

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
                                  nn.Linear(28*28, 512), #输入是28*28的张量,也就是图片
                                  nn.LeakyReLU(), # 小于0的时候保存一部分梯度
                                  nn.Linear(512, 256),
                                  nn.LeakyReLU(),
                                  nn.Linear(256, 1), # 二分类问题,输出到1上
                                  nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.main(x)
        return x

初始化模型、优化器及损失计算函数

# 定义设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 初始化模型
gen = Generator().to(device)
dis = Discriminator().to(device)
# 优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001)
# 损失函数
loss_fn = torch.nn.BCELoss()

绘图函数

def gen_img_plot(model, epoch, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i] + 1)/2) # 确保prediction[i] + 1)/2输出的结果是在0-1之间
        plt.axis('off')
    plt.show()
    
test_input = torch.randn(16, 100, device=device)

GAN的训练

# 保存每个epoch所产生的loss值
D_loss = []
G_loss = []

# 训练循环
for epoch in range(20): #训练20个epoch
   d_epoch_loss = 0 # 初始损失值为0
   g_epoch_loss = 0
   # len(dataloader)返回批次数,len(dataset)返回样本数
   count = len(dataloader)
   # 对dataloader进行迭代
   for step, (img, _) in enumerate(dataloader): # enumerate加序号
       img = img.to(device) #将数据上传到设备
       size = img.size(0) # 获取每一个批次的大小
       random_noise = torch.randn(size, 100, device=device)  # 随机噪声的大小是size个
       
       d_optim.zero_grad() # 将判别器前面的梯度归0
       
       real_output = dis(img)      # 判别器输入真实的图片,real_output是对真实图片的预测结果 
       
       # 得到判别器在真实图像上的损失
       # 判别器对于真实的图片希望输出的全1的数组,将真实的输出与全1的数组进行比较
       d_real_loss = loss_fn(real_output, 
                             torch.ones_like(real_output))      
       d_real_loss.backward() # 求解梯度
       
       
       gen_img = gen(random_noise)    
       # 判别器输入生成的图片,fake_output是对生成图片的预测
       # 优化的目标是判别器,对于生成器的参数是不需要做优化的,需要进行梯度阶段,detach()会截断梯度,
       # 得到一个没有梯度的Tensor,这一点很关键
       fake_output = dis(gen_img.detach()) 
       # 得到判别器在生成图像上的损失
       d_fake_loss = loss_fn(fake_output, 
                             torch.zeros_like(fake_output))      
       d_fake_loss.backward() # 求解梯度
       
       d_loss = d_real_loss + d_fake_loss # 判别器总的损失等于两个损失之和
       d_optim.step() # 进行优化
       
       g_optim.zero_grad() # 将生成器的所有梯度归0
       fake_output = dis(gen_img) # 将生成器的图片放到判别器中,此时不做截断,因为要优化生成器
       # 生层器希望生成的图片被判定为真
       g_loss = loss_fn(fake_output, 
                        torch.ones_like(fake_output))      # 生成器的损失
       g_loss.backward() # 计算梯度
       g_optim.step() # 优化
       
       # 将损失累加到定义的数组中,这个过程不需要计算梯度
       with torch.no_grad():
           d_epoch_loss += d_loss
           g_epoch_loss += g_loss
     
   # 计算每个epoch的平均loss,仍然使用这个上下文关联器
   with torch.no_grad():
       # 计算平均的loss值
       d_epoch_loss /= count
       g_epoch_loss /= count
       # 将平均loss放入到loss数组中
       D_loss.append(d_epoch_loss.item())
       G_loss.append(g_epoch_loss.item())
       # 打印当前的epoch
       print('Epoch:', epoch)
       # 调用绘图函数
       gen_img_plot(gen, epoch, test_input)

输出

Epoch: 0

…(省略中间的迭代输出)

Epoch: 19

总共做了20次的迭代,可以看出,随着迭代次数的增加,生成的图片质量越来越好。

整体代码


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

# 对数据做归一化 (-1, 1)
transform = transforms.Compose([
    transforms.ToTensor(),         # 将数据转换成Tensor格式,channel, high, witch,数据在(0, 1)范围内
    transforms.Normalize(0.5, 0.5) # 通过均值和方差将数据归一化到(-1, 1)之间
])

# 下载数据集
train_ds = torchvision.datasets.MNIST('data',
                                      train=True,
                                      transform=transform,
                                      download=True)
                                      
# 设置dataloader
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

# 返回一个批次的数据
imgs, _ = next(iter(dataloader))

# imgs的大小
imgs.shape

# 输入是长度为 100 的 噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
# linear 1 :   100----256
# linear 2:    256----512
# linear 2:    512----28*28
# reshape:     28*28----(1, 28, 28)

class Generator(nn.Module): #创建的 Generator 类继承自 nn.Module
    def __init__(self): # 定义初始化方法
        super(Generator, self).__init__() #继承父类的属性
        self.main = nn.Sequential( #使用Sequential快速创建模型
                                  nn.Linear(100, 256),
                                  nn.ReLU(),
                                  nn.Linear(256, 512),
                                  nn.ReLU(),
                                  nn.Linear(512, 28*28),
                                  nn.Tanh()                     # 输出层使用Tanh()激活函数,使输出-1, 1之间
        )
    def forward(self, x):              # 定义前向传播 x 表示长度为100 的noise输入
        img = self.main(x)
        img = img.view(-1, 28, 28) #将img展平,转化成图片的形式,channel为1可写可不写
        return img

## 输入为(1, 28, 28)的图片  输出为二分类的概率值,输出使用sigmoid激活 0-1
# BCEloss计算交叉熵损失

# nn.LeakyReLU   f(x) : x>0 输出 x, 如果x<0 ,输出 a*x  a表示一个很小的斜率,比如0.1
# 判别器中一般推荐使用 LeakyReLU

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
                                  nn.Linear(28*28, 512), #输入是28*28的张量,也就是图片
                                  nn.LeakyReLU(), # 小于0的时候保存一部分梯度
                                  nn.Linear(512, 256),
                                  nn.LeakyReLU(),
                                  nn.Linear(256, 1), # 二分类问题,输出到1上
                                  nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.main(x)
        return x

# 定义设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 初始化模型
gen = Generator().to(device)
dis = Discriminator().to(device)
# 优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001)
# 损失函数
loss_fn = torch.nn.BCELoss()

def gen_img_plot(model, epoch, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i] + 1)/2) # 确保prediction[i] + 1)/2输出的结果是在0-1之间
        plt.axis('off')
    plt.show()
    
test_input = torch.randn(16, 100, device=device)

 # 保存每个epoch所产生的loss值
D_loss = []
G_loss = []

# 训练循环
for epoch in range(20): #训练20个epoch
    d_epoch_loss = 0 # 初始损失值为0
    g_epoch_loss = 0
    # len(dataloader)返回批次数,len(dataset)返回样本数
    count = len(dataloader)
    # 对dataloader进行迭代
    for step, (img, _) in enumerate(dataloader): # enumerate加序号
        img = img.to(device) #将数据上传到设备
        size = img.size(0) # 获取每一个批次的大小
        random_noise = torch.randn(size, 100, device=device)  # 随机噪声的大小是size个
        
        d_optim.zero_grad() # 将判别器前面的梯度归0
        
        real_output = dis(img)      # 判别器输入真实的图片,real_output是对真实图片的预测结果 
        
        # 得到判别器在真实图像上的损失
        # 判别器对于真实的图片希望输出的全1的数组,将真实的输出与全1的数组进行比较
        d_real_loss = loss_fn(real_output, 
                              torch.ones_like(real_output))      
        d_real_loss.backward() # 求解梯度
        
        
        gen_img = gen(random_noise)    
        # 判别器输入生成的图片,fake_output是对生成图片的预测
        # 优化的目标是判别器,对于生成器的参数是不需要做优化的,需要进行梯度阶段,detach()会截断梯度,
        # 得到一个没有梯度的Tensor,这一点很关键
        fake_output = dis(gen_img.detach()) 
        # 得到判别器在生成图像上的损失
        d_fake_loss = loss_fn(fake_output, 
                              torch.zeros_like(fake_output))      
        d_fake_loss.backward() # 求解梯度
        
        d_loss = d_real_loss + d_fake_loss # 判别器总的损失等于两个损失之和
        d_optim.step() # 进行优化
        
        g_optim.zero_grad() # 将生成器的所有梯度归0
        fake_output = dis(gen_img) # 将生成器的图片放到判别器中,此时不做截断,因为要优化生成器
        # 生层器希望生成的图片被判定为真
        g_loss = loss_fn(fake_output, 
                         torch.ones_like(fake_output))      # 生成器的损失
        g_loss.backward() # 计算梯度
        g_optim.step() # 优化
        
        # 将损失累加到定义的数组中,这个过程不需要计算梯度
        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss
      
    # 计算每个epoch的平均loss,仍然使用这个上下文关联器
    with torch.no_grad():
        # 计算平均的loss值
        d_epoch_loss /= count
        g_epoch_loss /= count
        # 将平均loss放入到loss数组中
        D_loss.append(d_epoch_loss.item())
        G_loss.append(g_epoch_loss.item())
        # 打印当前的epoch
        print('Epoch:', epoch)
        # 调用绘图函数
        gen_img_plot(gen, epoch, test_input)

参考资料

[1] https://www.bilibili.com/video/BV1xm4y1X7KZ
[2] https://blog.csdn.net/hshudoudou/article/details/126922562?spm=1001.2014.3001.5502

pytorch模型保存加载与续训练(代码片段)

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题🍊往期回顾:对抗生成网络GAN系列——GAN原理及手写数字生成小案例  对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例🍊近期目标:... 查看详情

翻车现场:我用pytorch和gan做了一个生成神奇宝贝的失败模型(代码片段)

前言神奇宝贝已经是一个家喻户晓的动画了,我们今天来确认是否可以使用深度学习为他自动创建新的Pokemon。我最终成功地使用了生成对抗网络(GAN)生成了类似Pokemon的图像,但是这个图像看起来并不像神奇宝贝。虽然这个尝... 查看详情

[pytorch系列-72]:生成对抗网络gan-图像生成开源项目pytorch-cyclegan-and-pix2pix-使用预训练模型训练cyclegan模型(代码片段)

...本思路1.3训练方式第2章测试步骤第1步:下载或克隆pytorch-CycleGAN-and-pix2pix所有代码第2步:切换当前目录第3步:安装依赖文件(可视化工具)第4步:下载CycleGAN数据集第5步:下载预训练模型第6步:... 查看详情

[pytorch系列-66]:生成对抗网络gan-图像生成开源项目pytorch-cyclegan-and-pix2pix-使用预训练模型测试pix2pix模型(代码片段)

...路1.2本章基本思路第2章测试步骤第1步:下载或克隆pytorch-CycleGAN-and-pix2pix所有代码第2步:切换当前目录第3步:安装依赖文件(可视化工具)第4步:下载pix2pix数据集第5步:下载预训练模型第6步:... 查看详情

[pytorch系列-71]:生成对抗网络gan-图像生成开源项目pytorch-cyclegan-and-pix2pix-使用预训练模型训练pix2pix模型(代码片段)

...本思路1.3训练方式第2章测试步骤第1步:下载或克隆pytorch-CycleGAN-and-pix2pix所有代码第2步:切换当前目录第3步:安装依赖文件(可视化工具)第4步:下载pix2pix数据集第5步:下载预训练模型第6步:... 查看详情

[pytorch系列-68]:生成对抗网络gan-图像生成开源项目pytorch-cyclegan-and-pix2pix-使用预训练模型测试cyclegan模型(代码片段)

...路1.2本章基本思路第2章测试步骤第1步:下载或克隆pytorch-CycleGAN-and-pix2pix所有代码第2步:切换当前目录第3步:安装依赖文件(可视化工具)第4步:下载CycleGAN数据集第5步:下载预训练模型第6步:... 查看详情

[pytorch系列-61]:生成对抗网络gan-基本原理-自动生成手写数字案例分析(代码片段)

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121914862目录第1章基本原理第2章准备条件第3章数据集3.1自动下载minist数... 查看详情

中公的深度学习培训怎么样?有人了解吗?

...的分布式处理及项目实战多GPU并行实现分布式并行的环境搭建分布式并行实现第六阶段深度强化学习及项目实战强化学习介绍智能体Agent的深度决策机制(上)智能体Agent的深度决策机制(中)智能体Agent的深度决策机制(下)第... 查看详情

pytorch自定义数据集模型训练流程(代码片段)

文章目录Pytorch模型自定义数据集训练流程1、任务描述2、导入各种需要用到的包3、分割数据集4、将数据转成pytorch标准的DataLoader输入格式5、导入预训练模型,并修改分类层6、开始模型训练7、利用训好的模型做预测Pytorch模... 查看详情

pytorch笔记-diffusionmodel源码开发(代码片段)

DiffusionModel的效果如下:源码如下:选择一个数据集确定超参数的值确定扩散过程任意时刻的采样值演示原始数据分布加噪100步后的效果编写拟合逆扩散过程高斯分布的模型编写训练的误差函数编写逆扩散采样函数(inference过程)... 查看详情

加载pytorch中的预训练模型及部分结构的导入(代码片段)

torchvision.modelmodel子包中包含了用于处理不同任务的经典模型的定义,包括:图像分类、像素级语义分割、对象检测、实例分割、人员关键点检测和视频分类。图像分类:语义分割: 对象检测、实例分割和人员关键点检测:&n... 查看详情

在ubuntu下将pytorch模型部署到c++(环境搭建)(代码片段)

...装opencv安装及编译本文总结记录了在将lightweight-openpose的pytorch网络模型部署到c++的过程以及遇到的一些问题,如有错误敬请指正vscodec++环境部署cmakevscode中c++编译过程主 查看详情

[pytorch系列-63]:生成对抗网络gan-图像生成开源项目pytorch-cyclegan-and-pix2pix-代码总体架构与总体学习思路

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121940011目录第1章理论概述1.1 普通GAN,pix2pix,CycleGAN和pix2pixHD的演变过程... 查看详情

mae实现及预训练可视化(cifar-pytorch)(代码片段)

MAE实现及预训练可视化(CIFAR-Pytorch)文章目录MAE实现及预训练可视化(CIFAR-Pytorch)灵感来源自监督学习自监督的发展MAE(MaskedAutoencoders)方法介绍MAE流程图搭建MAE模型MAE组网MAE预训练(pretrain)EncoderDecoder总结测... 查看详情

mae实现及预训练可视化(cifar-pytorch)(代码片段)

MAE实现及预训练可视化(CIFAR-Pytorch)文章目录MAE实现及预训练可视化(CIFAR-Pytorch)灵感来源自监督学习自监督的发展MAE(MaskedAutoencoders)方法介绍MAE流程图搭建MAE模型MAE组网MAE预训练(pretrain)EncoderDecoder总结测... 查看详情

mae实现及预训练可视化(cifar-pytorch)(代码片段)

MAE实现及预训练可视化(CIFAR-Pytorch)文章目录MAE实现及预训练可视化(CIFAR-Pytorch)灵感来源自监督学习自监督的发展MAE(MaskedAutoencoders)方法介绍MAE流程图搭建MAE模型MAE组网MAE预训练(pretrain)EncoderDecoder总结测... 查看详情

pytorch笔记-diffusionmodel源码开发(代码片段)

DiffusionModel的效果如下:源码如下:选择一个数据集确定超参数的值确定扩散过程任意时刻的采样值演示原始数据分布加噪100步后的效果编写拟合逆扩散过程高斯分布的模型编写训练的误差函数编写逆扩散采样函数(inference过程)... 查看详情

[pytorch系列-67]:生成对抗网络gan-图像生成开源项目pytorch-cyclegan-and-pix2pix-使用预训练模型进行测试pix2pix模型(代码片段)

...HiWangWenBing/article/details/122030093目录第1步:下载或克隆pytorch-CycleGAN-and-pix2pix所有代码第2步:切换当前目录第3步:安装依赖文件:第4步:下载pix2pix数据集第5步:下载预训练模型第6步:训练模型第7步... 查看详情