《Generative Adversarial Nets》是 GAN 系列的鼻祖。在这里通过 PyTorch 实现 GAN ,并且用于手写数字生成。

摘要: 我们提出了一个新的框架,通过对抗处理来评估生成模型。其中,我们同时训练两个 model :一个是生成模型 G,用于获取数据分布;另一个是判别模型 D,用来预测样本来自训练数据而不是生成模型 G 的概率。G 的训练过程是最大化 D 犯错的概率。这个框架对应于一个极小极大的二人游戏。在任意函数 G 和 D 的空间中,存在着一个唯一的解,G 恢复训练数据的分布而 D 一直等于1/2. 在 G 和 D 都由多层感知器定义的情况下,整个系统可以通过反向传播进行训练。  

import time
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import DataLoader if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True

要导入的包

#########################
## SETTINGS
######################### # Device
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") # Hyperparameters
random_seed = 123
generator_learning_rate = 0.001
discriminator_learning_rate = 0.001
num_epochs = 100
batch_size = 128
LATENT_DIM = 100
IMG_SHAPE = (1, 28, 28)
IMG_SIZE = 1
for x in IMG_SHAPE:
IMG_SIZE *= x

设置超参数

#########################
## MNIST DATASET
######################### train_dataset = datasets.MNIST(root='../data',
train=True,
transform=transforms.ToTensor(),
download=True) test_dataset = datasets.MNIST(root='../data',
train=False,
transform=transforms.ToTensor()) train_loader = DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True) test_loader = DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False) # Checking the dataset
for images, labels in train_loader:
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels.shape)
break # 输出 # Image batch dimensions: torch.Size([128, 1, 28, 28])
# Image label dimensions: torch.Size([128])

加载MNIST数据集

##############################
## MODEL
############################## class GAN(torch.nn.Module): def __init__(self):
super(GAN, self).__init__() self.generator = nn.Sequential(
nn.Linear(LATENT_DIM, 128),
nn.LeakyReLU(inplace=True),
nn.Dropout(p=0.5),
nn.Linear(128, IMG_SIZE),
nn.Tanh()
) self.discriminator = nn.Sequential(
nn.Linear(IMG_SIZE, 128),
nn.LeakyReLU(inplace=True),
nn.Dropout(p=0.5),
nn.Linear(128, 1),
nn.Sigmoid()
) def generator_forward(self, z):
img = self.generator(z)
return img def discriminator_forward(self, img):
pred = model.discriminator(img)
return pred.view(-1)

GAN—Model

start_time = time.time()

discr_costs = []
gener_costs = [] for epoch in range(num_epochs):
model = model.train()
for batch_idx, (features, targets) in enumerate(train_loader): features = (features - 0.5) * 2.
features = features.view(-1, IMG_SIZE).to(device)
targets = targets.to(device) # Adversarial ground truths
valid = torch.ones(targets.size(0)).float().to(device)
fake = torch.zeros(targets.size(0)).float().to(device) ### FORWARD AND BACK PROP # ---------------------
# Train Generator
# --------------------- # make new images
z = torch.zeros((targets.size(0), LATENT_DIM)).uniform_(-1.0, 1.0).to(device) # generate a batch of images
generated_features = model.generator_forward(z) # Loss measures generators's ability to fool the discriminator
discr_pred = model.discriminator_forward(generated_features)
gener_loss = F.binary_cross_entropy(discr_pred, valid) optim_gener.zero_grad()
gener_loss.backward()
optim_gener.step() # ---------------------
# Train Discriminator
# --------------------- # Measure discriminator's ability to classify real from samples
discr_pred_real = model.discriminator_forward(features.view(-1, IMG_SIZE))
real_loss = F.binary_cross_entropy(discr_pred_real, valid)
discr_pred_fake = model.discriminator_forward(generated_features.detach())
fake_loss = F.binary_cross_entropy(discr_pred_fake, fake)
discr_loss = 0.5 * (real_loss + fake_loss) optim_discr.zero_grad()
discr_loss.backward()
optim_discr.step() discr_costs.append(discr_loss)
gener_costs.append(gener_loss) ### LOGGING
if not batch_idx % 100:
print('Epoch: %03d/%03d | Batch %03d/%03d | Gen/Dis Loss: %.4f/%.4f'
%(epoch+1, num_epochs, batch_idx, len(train_loader), gener_loss, discr_loss)) print('Time elapsed: %.2f min' % ((time.time() - start_time)/60)) print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

网络训练

画出 generator loss 和 discriminator loss 的变化图:

plt.plot(range(len(gener_costs)), gener_costs, label='generator loss')
plt.plot(range(len(discr_costs)), discr_costs, label='discriminator loss')
plt.legend()
plt.savefig('./loss.jpg')
plt.show()

利用以上训练的 Generator 生成一些仿手写数字图片:

#########################
## VISUALIZATION
######################### model.eval()
# Make new images
z = torch.zeros((5, LATENT_DIM)).uniform_(-1.0, 1.0).to(device)
generated_features = model.generator_forward(z)
imgs = generated_features.view(-1, 28, 28) fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(20, 2.5)) for i, ax in enumerate(axes):
axes[i].imshow(imgs[i].detach().numpy(), cmap='binary')

再生成几次:

可以发现,以上生成的数字图片有些很清晰,但有些很模糊,不易辨认,但是结果已经让人很兴奋了~~

后续可以对GAN进行改进,从而生成质量更高的图片。

Reference

  [1] deeplearning-models——Github

  [2] Paper《Generative Adversarial Network 

 

GAN——生成手写数字的更多相关文章

  1. GAN实战笔记——第三章第一个GAN模型:生成手写数字

    第一个GAN模型-生成手写数字 一.GAN的基础:对抗训练 形式上,生成器和判别器由可微函数表示如神经网络,他们都有自己的代价函数.这两个网络是利用判别器的损失记性反向传播训练.判别器努力使真实样本输 ...

  2. 卷积生成对抗网络(DCGAN)---生成手写数字

    深度卷积生成对抗网络(DCGAN) ---- 生成 MNIST 手写图片 1.基本原理 生成对抗网络(GAN)由2个重要的部分构成: 生成器(Generator):通过机器生成数据(大部分情况下是图像 ...

  3. Tensorflow:DCGAN生成手写数字

    参考地址:https://blog.csdn.net/miracle_ma/article/details/78305991 使用DCGAN(deep convolutional GAN):深度卷积G ...

  4. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  5. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  6. 基于opencv的手写数字识别(MFC,HOG,SVM)

    参考了秋风细雨的文章:http://blog.csdn.net/candyforever/article/details/8564746 花了点时间编写出了程序,先看看效果吧. 识别效果大概都能正确. ...

  7. 【机器学习】BP神经网络实现手写数字识别

    最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...

  8. 深度学习-使用cuda加速卷积神经网络-手写数字识别准确率99.7%

    源码和运行结果 cuda:https://github.com/zhxfl/CUDA-CNN C语言版本参考自:http://eric-yuan.me/ 针对著名手写数字识别的库mnist,准确率是9 ...

  9. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

随机推荐

  1. Ubuntu16.04的搭建l.2.t.p.d(宿舍访问公司内网)

    主要的实现步骤 openswan(ipsec) : 提供一个密钥 ppp :提供用户名和密码 xl2tpd : 提供L2TP服务 sysctl : 提供服务器内部转发 iptables : 提供请求从 ...

  2. iOS开发 简单实现视频音频的边下边播 (转)

      1.ios视频音频边缓存边播放,缓存时可以在已下载的部分拖拽进度条. 3.无论是下载到一半退出还是下载完退出,已缓存的数据都存到自己指定的一个路径.如果已下载完,下次播放时可以不再走网络,直接播放 ...

  3. 北京地铁出行线路规划系统项目总结(Java+Flask+Vue实现)

    北京地铁出行线路规划系统项目总结 GitHub仓库地址:https://github.com/KeadinZhou/SE-Subway Demo地址:http://10.66.2.161:8080/ ...

  4. Transformer模型---encoder

    一.简介 论文链接:<Attention is all you need> 由google团队在2017年发表于NIPS,Transformer 是一种新的.基于 attention 机制 ...

  5. 201871010118-唐敬博《面向对象程序设计(java)》第十一周学习总结

    在博客园撰写博客(随笔),总结10周学习内容,作业格式要求如下: 博文名称:学号-姓名<面向对象程序设计(java)>第十一周学习总结(1分) 博文正文开头格式:(2分) 项目 内容 这个 ...

  6. 201871010135-张玉晶《面向对象程序设计(java)》第十周学习总结

    201871010135-张玉晶<面向对象程序设计(java)>第十周学习总结 项目 内容 这个作业属于哪个课程 https://www.cnblogs.com/nwnu-daizh/ 这 ...

  7. 【大数据】0001---使用SparkSQL关联两个表求和取前几行

    场景: 有两个表,表可以是文本或Json数据,结构化后分别是Table1(A,B,C)和Table2(C.D.E),两个表通过C关联,要求求出D+E之和,并以(A.B.D+E)三列返回 解答: 思路: ...

  8. Dubbo支持的注册中心(二)

    1. Zookeeper 优点:支持网络集群 缺点:稳定性受限于 Zookeeper 2. Redis 优点:对服务器环境要求较高 缺点:对服务器环境要求较高 3. Multicast 优点:去中心化 ...

  9. scp、rsync、xsync

    scp. 拷贝完全相同 scp -r etc/hadoop/dfs.hosts root@192.168.121.134:/usr/local/hadoop/hadoop-2.7.6/etc/hado ...

  10. JDOJ 3055: Nearest Common Ancestors

    JDOJ 3055: Nearest Common Ancestors JDOJ传送门 Description 给定N个节点的一棵树,有K次查询,每次查询a和b的最近公共祖先. 样例中的16和7的公共 ...