小白学习pytorch教程十一基于mnist数据集训练第一个生成性对抗网络(代码片段)

刘润森! 刘润森!     2022-12-11     499

关键词:

@Author:Runsen

GAN 是使用两个神经网络模型训练的生成模型。一种模型称为生成网络模型,它学习生成新的似是而非的样本。另一个模型被称为判别网络,它学习区分生成的例子和真实的例子。

生成性对抗网络

2014,蒙特利尔大学的Ian Goodfellow和他的朋友发明了生成性对抗网络(GAN)。自它出版以来,有许多它的变体和客观功能来解决它的问题

论文在这里找到.

论文提出了两种模型:生成模型和判别模型。两个模型竞争,以产生真实和假的样本。2016年,Yann LeCun将GANs描述为“过去二十年机器学习中最酷的想法”。

GAN 的大部分研究和应用都集中在计算机视觉领域。

其原因是卷积神经网络 (CNN) 等深度学习模型在过去 5 到 7 年中在计算机视觉领域取得了巨大成功,例如在具有挑战性的任务(如对象检测和人脸识别。

GAN 的典型例子是生成新的逼真的照片,最令人吃惊的是生成照片般逼真的人脸的例子。

在本教程中,我们将实现一个简单的GAN生成假的MNIST样本。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils

import numpy as np
import matplotlib.pyplot as plt
# CPU / GPU Setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)  #cuda

使用MNIST数据集,具有最小大小的数据集。

它由60000个训练图像和10000个测试图像组成,每个图像有28*28的大小和一个彩色通道。

# Define a transform 
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, ), std = (0.5, ))
])

# batch_size是一个前向和后向传播过程中的图像数。
batch_size = 100

mnist = datasets.MNIST('./data/MNIST', 
                       download = True, 
                       train = True, 
                       transform = transform)

mnist_loader = DataLoader(dataset = mnist, 
                          batch_size = batch_size, 
                          shuffle = True)
# CPU
def imshow(img, title):
    img = utils.make_grid(img.cpu().detach())
    img = (img+1)/2
    npimg = img.detach().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()
#GPU
def imshow(img, title):
    npimg = img.detach().numpy()
    fig = plt.figure(figsize = (10, 10))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()

images, labels = iter(mnist_loader).next()
imshow(images[0:16, :, :], "MNIST Images")

建立一个GANs模型。一个Generator和Discriminator

GANs由完全连接的层组成。它将从100维高斯分布采样的噪声转换为MNIST图像。鉴别器网络也由完全连接的层组成,用于区分输入数据是真是假。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        latent_size = 100
        output = 28*28
        
        self.main = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(inplace=True),
            
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            
            nn.Linear(512, output),
            nn.Tanh()
        )
        
    def forward(self, x):
        out = self.main(x)
        out = out.view(-1, 1, 28, 28)
        return out


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        n_features = 28 * 28
        n_out = 1
        
        self.main = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(inplace=True),
            
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            
            nn.Linear(64, n_out),
            nn.Sigmoid()        
        )
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        out = self.main(x)
        return out

G = Generator().to(device)
D = Discriminator().to(device)

生成性对抗网络训练过程的损失函数是二进制交叉熵损失,由torch.nn.BCELoss实现。

这两种模型都使用torch.optim.Adam作为优化工具,学习率设置为0.002。

# Objective Function
criterion = nn.BCELoss()

# Optimizer
G_optimizer = optim.Adam(G.parameters(), lr = 0.0002)
D_optimizer = optim.Adam(D.parameters(), lr = 0.0002)

# Constants
noise_dim = 100
num_epochs = 50
total_batch = len(mnist_loader)

# Lists
G_losses = []
D_losses = []

# Noise
sample_size = 16
fixed_noise = torch.randn(sample_size, noise_dim).to(device)

# Train
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(mnist_loader):
        
        # Images #
        images = images.reshape(batch_size, -1).float().to(device)
        
        # Labels #
        ones = torch.ones(batch_size, 1).to(device)
        zeros = torch.zeros(batch_size, 1).to(device)
        
        # Noise #
        noise = torch.randn(batch_size, noise_dim).to(device)
        
        # Initialize Optimizers
        D_optimizer.zero_grad()
        G_optimizer.zero_grad()
        
        #######################
        # Train Discriminator #
        #######################
        
        # Forward Images #
        prob_real = D(images)
        D_real_loss = criterion(prob_real, ones)
        
        # Generate Samples #
        fake_images = G(noise)
        prob_fake = D(fake_images)
        
        # Forward Fake Samples and Calculate Discriminator Loss #
        D_fake_loss = criterion(prob_fake, zeros)
        D_loss = (D_real_loss + D_fake_loss).mean()
        
        # Back Propagation and Update
        D_loss.backward()
        D_optimizer.step()
        
        ###################
        # Train Generator #
        ###################
        
        fake_images = G(noise)
        prob_fake = D(fake_images)
        
        # According to the section 3 in paper,
        # early in learning, when G is very poor, D can reject samples from G.
        # In this case, log(1-D(G(z))) saturates. 
        # thus, train G to maximiaze log(D(G(z))) instead of minimizing log(1-D(G(z)))
        G_loss = criterion(prob_fake, ones)
        
        # Back Propagation and Update
        G_loss.backward()
        G_optimizer.step()
        
        # Save Losses for Plotting Later
        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())
        
        # Print Statistics #
        if (i + 1) % 100 == 0:
            print("Epoch [%d/%d] Iter [%d/%d], D_Loss: %.4f G_Loss: %.4f"
                  %(epoch+1, num_epochs, i+1, total_batch, D_loss.item(), G_loss.item()))
    
    # Generate Samples #
    if epoch % 1 == 0:
        fake_samples = G(fixed_noise)
        imshow(fake_samples, "Generated MNIST Images")
    
# Save Model Weights for Digit Generation
torch.save(G.state_dict(), './data/GAN.pkl')

plt.figure(figsize = (8, 6))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="Generator")
plt.plot(D_losses, label="Discriminator")
plt.xlabel("Iterations")
plt.ylabel("Losses")
plt.legend()
plt.show()

sample_size = 64
noise_dim = 100

noise = torch.randn(sample_size, noise_dim).to(device)

G.load_state_dict(torch.load('GAN.pkl'))
fake_samples = G(fixed_noise)
imshow(fake_samples, "Generated MNIST Images")

GAN生成性对抗网络的运用

  • 将语义图像翻译成城市景观和建筑物的照片。
  • 将卫星照片翻译成地图。
  • 从白天到晚上的照片翻译。
  • 将黑白照片翻译成彩色。


- 论文在这里找到.

- 上述代码的论文.

- 上述代码.

小白学习pytorch教程七基于乳腺癌数据集​​构建logistic二分类模型(代码片段)

...、文字分类都属于这一类。在这篇博客中,将学习如何在PyTorch中实现逻辑回归。文章目录1.数据集加载2.预处理3.模型搭建4.训练和优化1.数据集加载在这里,我将使用来自sklearn库的乳腺癌数据集。这是一个简单的二元类分类数据... 查看详情

小白学习pytorch教程十基于大型电影评论数据集训练第一个lstm模型(代码片段)

@Author:Runsen文章目录编码建立字典并对评论进行编码对标签进行编码删除异常值填充序列数据加载器RNN模型的实现训练本博客对原始IMDB数据集进行预处理,建立一个简单的深层神经网络模型,对给定数据进行情感分析... 查看详情

小白学习pytorch教程十七pytorch中数据集torchvision和torchtext(代码片段)

@Author:Runsen对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。之前使用torchDataLoader类直接加载图像并将其转换为张量。现在结合torchvision和torchtext介绍torch中的内置数据集Torchvision中的数据集MNISTMNIST... 查看详情

小白学习pytorch教程八使用图像数据增强手段,提升cifar-10数据集精确度(代码片段)

@Author:Runsen上次基于CIFAR-10数据集,使用PyTorch​​构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段。importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.datai 查看详情

pytorch基于cnn的手写数字识别(在mnist数据集上训练)(代码片段)

最终成果http://pytorch-cnn-mnist.herokuapp.com/GITHUBhttps://github.com/XavierJiezou/pytorch-cnn-mnist本文以最经典的mnist数据集为例,讲述了使用pytorch做机器学习的一整套流程,文中所提到的所有代码都可以到github中查看。项目场景简单的... 查看详情

小白学习pytorch教程十七pytorch中数据集torchvision和torchtext(代码片段)

@Author:Runsen对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。之前使用torchDataLoader类直接加载图像并将其转换为张量。现在结合torchvision和torchtext介绍torch中的内置数据集Torchvision中的数据集MNISTMNIST... 查看详情

[基于pytorch的mnist识别02]用户数据集的读取(代码片段)

写在前面pytorch包含了很多包括mnist在内的开源数据集,但是如果要建立自己的神经网络的话肯定需要训练自己的数据集,那么如何利用pytorch加载用户自己的数据集呢?今天就来解决这个问题。今天的工作需要加载用... 查看详情

基于pytorch平台实现对mnist数据集的分类分析(前馈神经网络softmax)基础版(代码片段)

基于pytorch平台实现对MNIST数据集的分类分析(前馈神经网络、softmax)基础版文章目录基于pytorch平台实现对MNIST数据集的分类分析(前馈神经网络、softmax)基础版前言一、基于“前馈神经网络”模型,分类分析... 查看详情

小白学习pytorch教程九基于pytorch训练第一个rnn模型(代码片段)

@Author:Runsen当阅读一篇课文时,我们可以根据前面的单词来理解每个单词的,而不是从零开始理解每个单词。这可以称为记忆。卷积神经网络模型(CNN)不能实现这种记忆,因此引入了递归神经网络模型(RNN)来解决这一问题... 查看详情

pytorch土堆pytorch教程学习torchvision中的数据集的使用(代码片段)

...据集。内置的数据集有CIFAR10、MNIST、COCO等,更多可进入pytorch官网查看。所有内置的数据集都继承了torch.utils.data.Dataset类,并且实现了__getitem__和__len__。所有的数据集几乎都有相似的API。下面以CIFAR10数据集的使用为例来认识下内... 查看详情

小白学习keras教程二基于cifar-10数据集训练简单的mlp分类模型(代码片段)

@Author:Runsen分类任务的MLP当目标(y)是离散的(分类的)对于损失函数,使用交叉熵;对于评估指标,通常使用accuracy数据集描述CIFAR-10数据集包含10个类中的60000个图像—50000个用于培训,10000个用于测试有关更多信息,请参阅... 查看详情

小白学习kears教程四keras基于数字数据集建立基础的cnn模型

@Author:Runsen文章目录基本卷积神经网络(CNN)加载数据集1.创建模型2.卷积层3.激活层4.池化层5.Dense(全连接层)6.Modelcompile&train基本卷积神经网络(CNN)-CNN的基本结构:CNN与MLP相似,因为它们只向前传送信号(前馈网络),... 查看详情

pytorch实现mlp并在mnist数据集上验证(代码片段)

...找一份源码以参考,重点在于照着源码手敲一遍,以熟悉pytorch的基本操作。实验要求熟悉pytorch的基本操作:用pytorch实现MLP,并在MNIST数据集上进行训练环境配置实验环境如下:Win10python3.8Anaconda3Cuda10.2+cudnnv7GPU:NVIDIAGeForceMX250配置... 查看详情

小白学习keras教程五基于reuters数据集训练不同rnn循环神经网络模型

@Author:Runsen文章目录循环神经网络RNNLoadDataset1.VanillaRNN2.StackedVanillaRNN3.LSTM4.StackedLSTM循环神经网络RNN前馈神经网络(例如MLP和CNN)功能强大,但它们并未针对处理“顺序”数据进行优化换句话说,它们不具有先前输入的“记忆”... 查看详情

用pytorch对以mnist数据集进行卷积神经网络(代码片段)

PyTorch构建深度学习网络一般步骤加载数据集\\colorred加载数据集加载数据集;定义网络结构模型\\colorred定义网络结构模型定义网络结构模型;定义损失函数\\colorred定义损失函数定义损失函数;定义优化算法\\colorred定... 查看详情

小白学习keras教程七基于digits数据集训练基本自动编码器无监督神经网络

@Author:Runsen本文博客目标:了解自动编码器的基本知识参考文献https://blog.keras.io/building-autoencoders-in-keras.htmlhttps://medium.com/@curiousily/credit-card-fraud-detection-using-autoencoders-in-keras-tensorflow-for-hack 查看详情

pytorch学习(代码片段)

Pytorch1_手写字识别数据集的载入数据集展示网络搭建模型训练训练损失函数图模型预测下面将构造一个三层全连接层的神经网络来对手写数字进行识别。数据集的载入importtorchfromtorchimportnnfromtorch.nnimportfunctionalasFfromtorchimportoptimim... 查看详情

小白学习keras教程六基于cifar-10数据集训练cnn-rnn神经网络模型(代码片段)

@Author:Runsen文章目录LoadDataset1.CNN-RNN2.CNN-RNN-2LoadDatasetCIFAR-10datasetimportnumpyasnpfromsklearn.metricsimportaccuracy_scorefromtensorflow.keras.datasetsimportcifar10fromtensorflow.keras.utilsimp 查看详情