[Everything Can Be GAN] Generative Adversarial Network for Handwritten Digit Generation, Part 2 (Code Snippets)

Author: 我是小白呀    2022-12-06

[Everything Can Be GAN] Generative Adversarial Network for Handwritten Digit Generation, Part 2

Overview

GAN (Generative Adversarial Network) is a generative model built from two networks: a generator (Generator) and a discriminator (Discriminator). The generator learns to map random noise to realistic-looking samples, while the discriminator learns to tell real samples from generated ones. By training the two networks against each other, a GAN learns useful features and the data distribution automatically, without hand-crafted feature engineering.
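Concretely, the two networks play the minimax game introduced in the original GAN paper:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p(z)}\left[\log\left(1 - D(G(z))\right)\right]$$

In the code below this objective is implemented with binary cross-entropy: the discriminator loss is the average of BCE(D(x), 1) on real images and BCE(D(G(z)), 0) on generated ones, and the generator is trained with the common non-saturating variant, i.e. its loss is BCE(D(G(z)), 1).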


Full Code

Model

model.py:

import numpy as np
import torch.nn as nn


class Generator(nn.Module):
    """生成器"""

    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            """
            block
            :param in_feat: 输入的特征维度
            :param out_feat: 输出的特征维度
            :param normalize: 归一化
            :return: block
            """
            layers = [nn.Linear(in_feat, out_feat)]

            # Batch normalization (0.8 is passed positionally, i.e. as eps)
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))

            # Activation
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # [b, 100] => [b, 128]
            *block(latent_dim, 128, normalize=False),
            # [b, 128] => [b, 256]
            *block(128, 256),
            # [b, 256] => [b, 512]
            *block(256, 512),
            # [b, 512] => [b, 1024]
            *block(512, 1024),
            # [b, 1024] => [b, 28 * 28 * 1] => [b, 784]
            nn.Linear(1024, int(np.prod(img_shape))),
            # Activation: Tanh keeps outputs in [-1, 1]
            nn.Tanh()
        )

    def forward(self, z, img_shape):
        # [b, 100] => [b, 784]
        img = self.model(z)
        # [b, 784] => [b, 1, 28, 28]
        img = img.view(img.size(0), *img_shape)

        # Return the generated images
        return img


class Discriminator(nn.Module):
    """判断器"""

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            # [b, 1, 28, 28] => [b, 784] => [b, 512]
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # Flatten the image into a vector
        img_flat = img.view(img.size(0), -1)

        validity = self.model(img_flat)

        return validity
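Before wiring these models into the training script, a quick sanity check is to push a batch of random noise through them and look at the output shapes. The snippet below is a small sketch of my own (the batch size of 4 is arbitrary), not part of the original post:

import torch

from model import Generator, Discriminator

latent_dim = 100
img_shape = (1, 28, 28)

G = Generator(latent_dim=latent_dim, img_shape=img_shape)
D = Discriminator(img_shape=img_shape)

z = torch.randn(4, latent_dim)   # a batch of 4 noise vectors
fake = G(z, img_shape)           # -> [4, 1, 28, 28]
score = D(fake)                  # -> [4, 1], values in (0, 1) from the Sigmoid
print(fake.shape, score.shape)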

Main Function

main.py:

import argparse
import time
import os
import numpy as np

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torchvision.transforms as transforms
from torchvision.utils import save_image

from model import Generator, Discriminator


def get_data(img_size, batch_size):
    """获取数据"""

    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    return dataloader
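
# (Sketch, not part of the original post.) Normalize([0.5], [0.5]) above rescales the MNIST
# pixels from [0, 1] to [-1, 1], which matches the Tanh output range of the generator.
# A quick check could look like:
#
#     loader = get_data(img_size=28, batch_size=64)
#     imgs, _ = next(iter(loader))
#     print(imgs.min().item(), imgs.max().item())   # roughly -1.0 and 1.0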


# ---------- Training ----------
def train_model(n_epochs, dataloader, image_shape, image_path):
    """
    Train the GAN
    :param n_epochs: number of training epochs
    :param dataloader: training DataLoader
    :param image_shape: image shape
    :param image_path: directory for saving sample images
    :return:
    """

    # Iterate for n_epochs epochs
    for epoch in range(n_epochs):
        for i, (imgs, _) in enumerate(dataloader):

            # Adversarial ground truths
            # Labels of shape [b, 1] filled with 1 (real)
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
            # Labels of shape [b, 1] filled with 0 (fake)
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as the generator input [b, 100]
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

            # Generate images [b, 1, 28, 28]
            gen_imgs = generator(z, image_shape)

            # Loss measures generator's ability to fool the discriminator
            # Generator loss: BCE of D(G(z)) against the "real" labels
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25],
                           "images//.png".format(image_path, batches_done), nrow=5,
                           normalize=True)
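
# (Note, not part of the original post.) In the loop above the generator is updated first,
# with the "valid" labels applied to its own fake images (the non-saturating GAN loss).
# The discriminator is then updated on the real batch plus gen_imgs.detach(); detach()
# stops this second backward pass from flowing gradients into the generator.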


def option():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=16384, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
    opt = parser.parse_args()
    return opt
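
# (Sketch, not part of the original post.) All of these defaults can be overridden on the
# command line; judging by the Namespace printed in the output below, the logged run used
# something like:
#
#     python main.py --n_epochs 300 --batch_size 30000 --n_cpu 16 --sample_interval 20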


if __name__ == "__main__":
    # Directory name (a run timestamp) for saving generated images
    image_path = time.strftime("%Y_%m_%d_%H_%M_%S")
    os.makedirs("images", exist_ok=True)
    os.makedirs("images/".format(image_path))

    # Hyperparameters
    opt = option()
    print(opt)

    img_shape = (opt.channels, opt.img_size, opt.img_size)

    cuda = True if torch.cuda.is_available() else False

    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize the generator and discriminator
    generator = Generator(latent_dim=opt.latent_dim, img_shape=img_shape)
    discriminator = Discriminator(img_shape=img_shape)

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
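    # (Note, not part of the original post.) The Tensor alias above is an older idiom for
    # picking the device; on current PyTorch the same effect is usually written as
    # device = torch.device("cuda" if cuda else "cpu") together with tensor.to(device).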

    dataloader = get_data(
        img_size=opt.img_size,
        batch_size=opt.batch_size
    )

    train_model(
        n_epochs=opt.n_epochs,
        dataloader=dataloader,
        image_shape=img_shape,
        image_path=image_path
    )
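
Once training finishes, new digits can be sampled from the generator alone. The lines below are a minimal sketch of my own (the sample count of 25 and the file name final.png are arbitrary, not from the original post); they would be appended at the end of the if __name__ == "__main__": block so they can reuse generator, Tensor, opt, img_shape and image_path defined above:

    # Switch to eval mode so BatchNorm uses its running statistics
    generator.eval()

    # Sample 25 digits from the trained generator and save them as a 5x5 grid
    with torch.no_grad():
        z = Tensor(np.random.normal(0, 1, (25, opt.latent_dim)))
        samples = generator(z, img_shape)
        save_image(samples, "images/{}/final.png".format(image_path), nrow=5, normalize=True)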

Output

Namespace(b1=0.5, b2=0.999, batch_size=30000, channels=1, img_size=28, latent_dim=100, lr=0.0002, n_cpu=16, n_epochs=300, sample_interval=20)
[Epoch 0/300] [Batch 0/2] [D loss: 0.713072] [G loss: 0.732759]
[Epoch 0/300] [Batch 1/2] [D loss: 0.625757] [G loss: 0.729401]
[Epoch 1/300] [Batch 0/2] [D loss: 0.557747] [G loss: 0.726514]
[Epoch 1/300] [Batch 1/2] [D loss: 0.501802] [G loss: 0.723442]
[Epoch 2/300] [Batch 0/2] [D loss: 0.456023] [G loss: 0.719633]
[Epoch 2/300] [Batch 1/2] [D loss: 0.421082] [G loss: 0.714690]
[Epoch 3/300] [Batch 0/2] [D loss: 0.397120] [G loss: 0.708184]
[Epoch 3/300] [Batch 1/2] [D loss: 0.382698] [G loss: 0.699813]
[Epoch 4/300] [Batch 0/2] [D loss: 0.376455] [G loss: 0.689178]
[Epoch 4/300] [Batch 1/2] [D loss: 0.375903] [G loss: 0.675870]
[Epoch 5/300] [Batch 0/2] [D loss: 0.379960] [G loss: 0.660027]
[Epoch 5/300] [Batch 1/2] [D loss: 0.387069] [G loss: 0.642543]
[Epoch 6/300] [Batch 0/2] [D loss: 0.396221] [G loss: 0.624336]
[Epoch 6/300] [Batch 1/2] [D loss: 0.406274] [G loss: 0.606929]
[Epoch 7/300] [Batch 0/2] [D loss: 0.416247] [G loss: 0.591856]
[Epoch 7/300] [Batch 1/2] [D loss: 0.424656] [G loss: 0.581296]
[Epoch 8/300] [Batch 0/2] [D loss: 0.431091] [G loss: 0.575895]
[Epoch 8/300] [Batch 1/2] [D loss: 0.435463] [G loss: 0.576295]
[Epoch 9/300] [Batch 0/2] [D loss: 0.438092] [G loss: 0.582505]
[Epoch 9/300] [Batch 1/2] [D loss: 0.439382] [G loss: 0.593778]
[Epoch 10/300] [Batch 0/2] [D loss: 0.440484] [G loss: 0.607302]
[Epoch 10/300] [Batch 1/2] [D loss: 0.441629] [G loss: 0.619709]
[Epoch 11/300] [Batch 0/2] [D loss: 0.444755] [G loss: 0.625674]
[Epoch 11/300] [Batch 1/2] [D loss: 0.451140] [G loss: 0.625349]
[Epoch 12/300] [Batch 0/2] [D loss: 0.461600] [G loss: 0.623004]
[Epoch 12/300] [Batch 1/2] [D loss: 0.475907] [G loss: 0.616745]
[Epoch 13/300] [Batch 0/2] [D loss: 0.492632] [G loss: 0.603465]
[Epoch 13/300] [Batch 1/2] [D loss: 0.509568] [G loss: 0.589097]
[Epoch 14/300] [Batch 0/2] [D loss: 0.524420] [G loss: 0.579904]
[Epoch 14/300] [Batch 1/2] [D loss: 0.532494] [G loss: 0.580591]
[Epoch 15/300] [Batch 0/2] [D loss: 0.533520] [G loss: 0.582400]
[Epoch 15/300] [Batch 1/2] [D loss: 0.528302] [G loss: 0.605238]
[Epoch 16/300] [Batch 0/2] [D loss: 0.520255] [G loss: 0.617953]
[Epoch 16/300] [Batch 1/2] [D loss: 0.512036] [G loss: 0.637511]
[Epoch 17/300] [Batch 0/2] [D loss: 0.506406] [G loss: 0.650202]
[Epoch 17/300] [Batch 1/2] [D loss: 0.507003] [G loss: 0.660099]
[Epoch 18/300] [Batch 0/2] [D loss: 0.515254] [G loss: 0.652070]
[Epoch 18/300] [Batch 1/2] [D loss: 0.531098] [G loss: 0.644829]
[Epoch 19/300] [Batch 0/2] [D loss: 0.546756] [G loss: 0.635194]
[Epoch 19/300] [Batch 1/2] [D loss: 0.557034] [G loss: 0.625947]
[Epoch 20/300] [Batch 0/2] [D loss: 0.562781] [G loss: 0.631091]


... ...

[Epoch 280/300] [Batch 1/2] [D loss: 0.593827] [G loss: 0.486869]
[Epoch 281/300] [Batch 0/2] [D loss: 0.488985] [G loss: 1.372436]
[Epoch 281/300] [Batch 1/2] [D loss: 0.478138] [G loss: 0.782779]
[Epoch 282/300] [Batch 0/2] [D loss: 0.445211] [G loss: 1.119433]
[Epoch 282/300] [Batch 1/2] [D loss: 0.450222] [G loss: 0.956143]
[Epoch 283/300] [Batch 0/2] [D loss: 0.459510] [G loss: 1.030325]
[Epoch 283/300] [Batch 1/2] [D loss: 0.481445] [G loss: 0.955175]
[Epoch 284/300] [Batch 0/2] [D loss: 0.494930] [G loss: 0.967074]
[Epoch 284/300] [Batch 1/2] [D loss: 0.513425] [G loss: 0.888255]
[Epoch 285/300] [Batch 0/2] [D loss: 0.523844] [G loss: 0.933630]
[Epoch 285/300] [Batch 1/2] [D loss: 0.528002] [G loss: 0.829575]
[Epoch 286/300] [Batch 0/2] [D loss: 0.510690] [G loss: 1.025084]
[Epoch 286/300] [Batch 1/2] [D loss: 0.511791] [G loss: 0.752756]
[Epoch 287/300] [Batch 0/2] [D loss: 0.491313] [G loss: 1.258785]
[Epoch 287/300] [Batch 1/2] [D loss: 0.539705] [G loss: 0.594626]
[Epoch 288/300] [Batch 0/2] [D loss: 0.523798] [G loss: 1.575793]
[Epoch 288/300] [Batch 1/2] [D loss: 0.595133] [G loss: 0.471295]
[Epoch 289/300] [Batch 0/2] [D loss: 0.477448] [G loss: 1.566631]
[Epoch 289/300] [Batch 1/2] [D loss: 0.477544] [G loss: 0.706210]
[Epoch 290/300] [Batch 0/2] [D loss: 0.429852] [G loss: 1.241274]
[Epoch 290/300] [Batch 1/2] [D loss: 0.437449] [G loss: 0.915428]
[Epoch 291/300] [Batch 0/2] [D loss: 0.434945] [G loss: 1.074615]
[Epoch 291/300] [Batch 1/2] [D loss: 0.449143] [G loss: 0.958715]
[Epoch 292/300] [Batch 0/2] [D loss: 0.454942] [G loss: 1.042261]
[Epoch 292/300] [Batch 1/2] [D loss: 0.469878] [G loss: 0.926095]
[Epoch 293/300] [Batch 0/2] [D loss: 0.466286] [G loss: 1.058231]
[Epoch 293/300] [Batch 1/2] [D loss: 0.471976] [G loss: 0.851719]
[Epoch 294/300] [Batch 0/2] [D loss: 0.459671] [G loss: 1.190166]
[Epoch 294/300] [Batch 1/2] [D loss: 0.488745] [G loss: 0.692813]
[Epoch 295/300] [Batch 0/2] [D loss: 0.478135] [G loss: 1.635507]
[Epoch 295/300] [Batch 1/2] [D loss: 0.606305] [G loss: 0.428279]
[Epoch 296/300] [Batch 0/2] [D loss: 0.525440] [G loss: 2.067884]
[Epoch 296/300] [Batch 1/2] [D loss: 0.568282] [G loss: 0.477078]
[Epoch 297/300] [Batch 0/2] [D loss: 0.402702] [G loss: 1.508152]
[Epoch 297/300] [Batch 1/2] [D loss: 0.412127] [G loss: 0.929811]
[Epoch 298/300] [Batch 0/2] [D loss: 0.439405] [G loss: 0.930188]
[Epoch 298/300] [Batch 1/2] [D loss: 0.471616] [G loss: 1.025511]
[Epoch 299/300] [Batch 0/2] [D loss: 0.524946] [G loss: 0.700306]
[Epoch 299/300] [Batch 1/2] [D loss: 0.537755] [G loss: 1.130932]

Generated Images

Sample images saved during training (figures omitted): 0.png, 100.png, 200.png, 300.png, 400.png, 500.png, 580.png.
