深度学习100例-生成对抗网络(gan)手写数字生成|第18天(代码片段)

K同学啊 K同学啊     2022-12-11     236

关键词:

🔱 大家好,我是 👉 K同学啊《深度学习100例》系列将持续更新 欢迎 点赞👍、收藏⭐、关注👀

本文将采用 GAN 模型实现手写数字的生成,重点是了解 GAN 模型的结构及其搭建方法。

一、前期工作

🚀 我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

🚀 深度学习新人必看:

  1. 小白入门深度学习 | 第一篇:配置深度学习环境
  2. 小白入门深度学习 | 第二篇:编译器的使用-Jupyter Notebook

🚀 往期精彩-卷积神经网络篇:

  1. 深度学习100例-卷积神经网络(CNN)实现mnist手写数字识别 | 第1天
  2. 深度学习100例-卷积神经网络(CNN)彩色图片分类 | 第2天
  3. 深度学习100例-卷积神经网络(CNN)服装图像分类 | 第3天
  4. 深度学习100例-卷积神经网络(CNN)花朵识别 | 第4天
  5. 深度学习100例-卷积神经网络(CNN)天气识别 | 第5天
  6. 深度学习100例-卷积神经网络(VGG-16)识别海贼王草帽一伙 | 第6天
  7. 深度学习100例-卷积神经网络(VGG-19)识别灵笼中的人物 | 第7天
  8. 深度学习100例-卷积神经网络(ResNet-50)鸟类识别 | 第8天
  9. 深度学习100例-卷积神经网络(AlexNet)手把手教学 | 第11天
  10. 深度学习100例-卷积神经网络(CNN)识别验证码 | 第12天
  11. 深度学习100例-卷积神经网络(Inception V3)识别手语 | 第13天
  12. 深度学习100例-卷积神经网络(Inception-ResNet-v2)识别交通标志 | 第14天
  13. 深度学习100例-卷积神经网络(CNN)实现车牌识别 | 第15天
  14. 深度学习100例-卷积神经网络(CNN)识别神奇宝贝小智一伙 | 第16天
  15. 深度学习100例-卷积神经网络(CNN)注意力检测 | 第17天

🚀 往期精彩-循环神经网络篇:

  1. 深度学习100例-循环神经网络(RNN)实现股票预测 | 第9天
  2. 深度学习100例-循环神经网络(LSTM)实现股票预测 | 第10天

🚀 本文选自专栏:《深度学习100例》

1. 设置GPU

如果使用的是CPU可以注释掉这部分的代码。

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
    
# 打印显卡信息,确认GPU可用
print(gpus)
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D

import matplotlib.pyplot as plt
import numpy             as np
import sys,os,pathlib

2. 定义训练参数

img_shape  = (28, 28, 1)
latent_dim = 200

二、什么是生成对抗网络

1. 简单介绍

生成对抗网络(GAN) 包含生成器和判别器,两个模型通过对抗训练不断学习、进化。

  • 生成器(Generator):生成数据(大部分情况下是图像),目的是“骗过”判别器。

  • 鉴别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器生成的“假数据”。

2. 应用领域

GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。

1)风格迁移

图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。

2)图像生成

GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。

三、网络结构

简单来讲,就是用生成器生成手写数字图像,用鉴别器鉴别图像的真假。二者相互对抗学习(卷),在对抗学习(卷)的过程中不断完善自己,直至生成器可以生成以假乱真的图片(鉴别器无法判断其真假)。结构图如下:

GAN步骤:

  • 1.生成器(Generator)接收随机数并返回生成图像。
  • 2.将生成的数字图像与实际数据集中的数字图像一起送到鉴别器(Discriminator)。
  • 3.鉴别器(Discriminator)接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。

四、构建生成器

def build_generator():
    # ======================================= #
    #     生成器,输入一串随机数字生成图片
    # ======================================= #
    model = Sequential([
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),               # 高级一点的激活函数
        layers.BatchNormalization(momentum=0.8),   # BN 归一化
        
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(np.prod(img_shape), activation='tanh'),
        layers.Reshape(img_shape)
    ])

    noise = layers.Input(shape=(latent_dim,))
    img = model(noise)

    return Model(noise, img)

五、构建鉴别器

def build_discriminator():
    # ===================================== #
    #   鉴别器,对输入的图片进行判别真假
    # ===================================== #
    model = Sequential([
        layers.Flatten(input_shape=img_shape),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1, activation='sigmoid')
    ])

    img = layers.Input(shape=img_shape)
    validity = model(img)

    return Model(img, validity)

  • 鉴别器训练原理:通过对输入的图片进行鉴别,从而达到提升的效果
  • 生成器训练原理:通过鉴别器对其生成的图片进行鉴别,来实现提升
# 创建判别器
discriminator = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

# 创建生成器 
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)

# 对生成的假图片进行预测
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

六、训练模型

1. 保存样例图片

def sample_images(epoch):
    """
    保存样例图片
    """
    row, col = 4, 4
    noise = np.random.normal(0, 1, (row*col, latent_dim))
    gen_imgs = generator.predict(noise)

    fig, axs = plt.subplots(row, col)
    cnt = 0
    for i in range(row):
        for j in range(col):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/%05d.png" % epoch)
    plt.close()

2. 训练模型

train_on_batch:函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型。

def train(epochs, batch_size=128, sample_interval=50):
    # 加载数据
    (train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()

    # 将图片标准化到 [-1, 1] 区间内   
    train_images = (train_images - 127.5) / 127.5
    # 数据
    train_images = np.expand_dims(train_images, axis=3)

    # 创建标签
    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    # 进行循环训练
    for epoch in range(epochs): 

        # 随机选择 batch_size 张图片
        idx = np.random.randint(0, train_images.shape[0], batch_size)
        imgs = train_images[idx]      
        
        # 生成噪音
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)
        gen_imgs = generator.predict(noise)
        
        # 训练鉴别器 
        d_loss_true = discriminator.train_on_batch(imgs, true)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # 返回loss值
        d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)

        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, true)
        
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # 保存样例图片
        if epoch % sample_interval == 0:
            sample_images(epoch)
train(epochs=30000, batch_size=256, sample_interval=200)
0 [D loss: 0.587824, acc.: 67.77%] [G loss: 0.634870]
1 [D loss: 0.387015, acc.: 74.22%] [G loss: 0.541133]
2 [D loss: 0.380705, acc.: 63.67%] [G loss: 0.455188]
3 [D loss: 0.408720, acc.: 56.25%] [G loss: 0.405431]
4 [D loss: 0.445802, acc.: 52.34%] [G loss: 0.343866]
......
176 [D loss: 0.394246, acc.: 66.41%] [G loss: 0.648134]
177 [D loss: 0.393966, acc.: 66.21%] [G loss: 0.640118]
178 [D loss: 0.402815, acc.: 65.62%] [G loss: 0.641665]
179 [D loss: 0.404573, acc.: 65.82%] [G loss: 0.647686]
180 [D loss: 0.394707, acc.: 67.19%] [G loss: 0.631329]
......

七、生成动图

如果报错:ModuleNotFoundError: No module named 'imageio' 可以使用:pip install imageio 安装 imageio 库。

import imageio

def compose_gif():
    # 图片地址
    data_dir = "F:/jupyter notebook/DL-100-days/code/images"
    data_dir = pathlib.Path(data_dir)
    paths    = list(data_dir.glob('*'))
    
    gif_images = []
    for path in paths:
        print(path)
        gif_images.append(imageio.imread(path))
    imageio.mimsave("test.gif",gif_images,fps=2)
    
compose_gif()
F:\\jupyter notebook\\DL-100-days\\code\\images\\00000.png
F:\\jupyter notebook\\DL-100-days\\code\\images\\00200.png
F:\\jupyter notebook\\DL-100-days\\code\\images\\00400.png
F:\\jupyter notebook\\DL-100-days\\code\\images\\00600.png
F:\\jupyter notebook\\DL-100-days\\code\\images\\00800.png
F:\\jupyter notebook\\DL-100-days\\code\\images\\01000.png
.....

图片的生成过程展示(约50s):


深度学习新人必看:

  1. 小白入门深度学习 | 第一篇:配置深度学习环境
  2. 小白入门深度学习 | 第二篇:编译器的使用-Jupyter Notebook

往期精彩

  1. 深度学习100例-卷积神经网络(CNN)实现mnist手写数字识别 | 第1天
  2. 深度学习100例-卷积神经网络(CNN)彩色图片分类 | 第2天
  3. 深度学习100例-卷积神经网络(CNN)服装图像分类 | 第3天
  4. 深度学习100例-卷积神经网络(CNN)花朵识别 | 第4天
  5. 深度学习100例-卷积神经网络(CNN)天气识别 | 第5天
  6. 深度学习100例-卷积神经网络(VGG-16)识别海贼王草帽一伙 | 第6天
  7. 深度学习100例-卷积神经网络(VGG-19)识别灵笼中的人物 | 第7天
  8. 深度学习100例-卷积神经网络(ResNet-50)鸟类识别 | 第8天
  9. 深度学习100例-循环神经网络(RNN)实现股票预测 | 第9天
  10. 深度学习100例-循环神经网络(LSTM)实现股票预测 | 第10天
  11. 深度学习100例-卷积神经网络(AlexNet)手把手教学 | 第11天
  12. 深度学习100例-卷积神经网络(CNN)识别验证码 | 第12天
  13. 深度学习100例-卷积神经网络(Inception V3)识别手语 | 第13天
  14. 深度学习100例-卷积神经网络(Inception-ResNet-v2)识别交通标志 | 第14天
  15. 深度学习100例-卷积神经网络(CNN)实现车牌识别 | 第15天
  16. 深度学习100例-卷积神经网络(CNN)识别神奇宝贝小智一伙 | 第16天
  17. 深度学习100例-卷积神经网络(CNN)注意力检测 | 第17天

🚀 来自专栏:《深度学习100例》

未完~

持续更新 欢迎 点赞👍、收藏⭐、关注👀

  • 点赞👍:点赞给我持续更新的动力
  • 收藏⭐️:收藏后你能够随时找到文章
  • 关注👀:关注我第一时间接收最新文章

深度学习100例-生成对抗网络(dcgan)手写数字生成|第19天(代码片段)

文章目录深度卷积生成对抗网络(DCGAN)一、前言二、什么是生成对抗网络?1.设置GPU2.加载和准备数据集三、创建模型1.生成器2.判别器四、定义损失函数和优化器1.判别器损失2.生成器损失五、定义训练循环六、训练... 查看详情

深度学习100例-生成对抗网络(dcgan)手写数字生成|第19天(代码片段)

文章目录深度卷积生成对抗网络(DCGAN)一、前言二、什么是生成对抗网络?1.设置GPU2.加载和准备数据集三、创建模型1.生成器2.判别器四、定义损失函数和优化器1.判别器损失2.生成器损失五、定义训练循环六、训练... 查看详情

学习手写数字生成(代码片段)

...动地址:CSDN21天学习挑战赛🍨本文为🔗365天深度学习训练营中的学习记录博客🍦参考文章地址:🔗深度学习100例-生成对抗网络(GAN)手写数字生成|第18天🍖作者:K同学啊敲例子今天敲的... 查看详情

学习手写数字生成(代码片段)

...动地址:CSDN21天学习挑战赛🍨本文为🔗365天深度学习训练营中的学习记录博客🍦参考文章地址:🔗深度学习100例-生成对抗网络(GAN)手写数字生成|第18天🍖作者:K同学啊敲例子今天敲的... 查看详情

《深度学习100例》数据和代码(代码片段)

《深度学习100例》分为《深度学习基础50例》与《深度学习进阶50例》,大家可以选择一次性订阅《深度学习100例》也可以分开订阅。《深度学习基础50例》:主要讲解深度学习中的一些基础算法,主要体现在目标识别... 查看详情

keras深度学习实战(22)——生成对抗网络详解与实现(代码片段)

Keras深度学习实战(22)——生成对抗网络详解与实现0.前言1.生成对抗网络原理2.模型分析3.利用生成对抗网络生成手写数字图像小结系列链接0.前言生成对抗网络(GenerativeAdversarialNetworks,GAN)使用神经网络生成与原始图像集... 查看详情

万物皆可gan生成对抗网络生成手写数字part1(代码片段)

【万物皆可GAN】生成对抗网络生成手写数字Part1概述GAN网络结构GAN训练流程模型详解生成器判别器概述GAN(GenerativeAdversarialNetwork)即生成对抗网络.GAN网络包括一个生成器(Generator)和一个判别器(Discriminator).GAN可以自动提取特征,并判... 查看详情

keras深度学习实战(22)——生成对抗网络详解与实现(代码片段)

Keras深度学习实战(22)——生成对抗网络详解与实现0.前言1.生成对抗网络原理2.模型分析3.利用生成对抗网络生成手写数字图像小结系列链接0.前言生成对抗网络(GenerativeAdversarialNetworks,GAN)使用神经网络生成与原始图像集非常相... 查看详情

手把手写深度学习:用gans生成手写数字(代码片段)

前言:2014年GANs在NPIS大会上被提出,但是因为种种原因沉寂了两年,直到DCGANs横空出世,将GAN和CNN完美结合,才真正打开了GANs井喷时代,一下子成为最强风口。好几年过去了,热度有增无减。这一讲从... 查看详情

手把手写深度学习:用gans生成手写数字(代码片段)

前言:2014年GANs在NPIS大会上被提出,但是因为种种原因沉寂了两年,直到DCGANs横空出世,将GAN和CNN完美结合,才真正打开了GANs井喷时代,一下子成为最强风口。好几年过去了,热度有增无减。这一讲从... 查看详情

万物皆可gan生成对抗网络生成手写数字part2(代码片段)

【万物皆可GAN】生成对抗网络生成手写数字Part2概述完整代码模型主函数输出结果生成的图片概述GAN(GenerativeAdversarialNetwork)即生成对抗网络.GAN网络包括一个生成器(Generator)和一个判别器(Discriminator).GAN可以自动提取特征,并判断和... 查看详情

ai生成手写数字+智能卡点切图

我的第一次尝试,采用生成对抗网络(GAN)生成minst手写数字,并实现智能卡点切图。除配乐外其他均由代码自动生成,代码将会在《深度学习100例》后期的文章中展示。 查看详情

深度学习笔记九:生成对抗网络gan(基本理论)

参考:李宏毅深度学习20172017.12.2更新:又有一些理解,简化语言和架构 查看详情

gan-生成对抗网络-生成手写数字(基于pytorch)(代码片段)

...erativeAdversarialNetwork),网络也如他的名字一样,有生成,有对抗,两个网络相互博弈。我们给两个网络起个名字,第一个网络用来生成数据命名为生成器(generator),另一个网络用来鉴别生成器生成... 查看详情

深度学习笔记十:生成对抗网络再思考f-gan

参考:李宏毅深度学习2017 查看详情

深度学习generativeadversarialnetworks,gan生成对抗网络分类(代码片段)

目录1思维导图2大纲3经典GAN简介3.1SGAN3.2ConditionalGAN3.3BidirectionalGAN(BiGAN)3.4InfoGAN3.5AuxiliaryClassifierGAN(AC-GAN)3.6BoundaryEquilibriumGAN(BEGAN)3.7Self-attentionGAN(SAGAN)3.8DeepConvolutionalGAN(DCGAN) 查看详情

《深度学习100例》数据和代码(代码片段)

《深度学习100例》分为《深度学习基础50例》与《深度学习进阶50例》,大家可以选择一次性订阅《深度学习100例》也可以分开订阅。《深度学习基础50例》:主要讲解深度学习中的一些基础算法,主要体现在目标识别,以及循环... 查看详情

对抗生成网络gan系列——gan原理及手写数字生成小案例(代码片段)

 🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题🍊往期回顾:目标检测系列——开山之作RCNN原理详解  目标检测系列——FastR-CNN原理详解  目标检测系列——FasterR-CNN原理详解🍊近期目标&... 查看详情