概述:在前期的文章中,我们用TensorFlow完成了对手写数字的识别,得到了94.09%的识别准确度,效果还算不错。在这篇文章中,笔者将带领大家用GAN模型,生成我们想要的手写数字。

GAN简介

对抗性生成网络(GenerativeAdversarial Network),由 Ian Goodfellow 首先提出,由两个网络组成,分别是generator网络(用于生成)和discriminator网络(用于判别)。GAN网络的目的就是使其自己生成一副图片,比如说经过对一系列猫的图片的学习,generator网络可以自己“绘制”出一张猫的图片,且尽量真实。discriminator网络则是用来进行判断的,将一张真实的图片和一张由generator网络生成的照片同时交给discriminator网络,不断训练discriminator网络,使其可以准确将discriminator网络生成的“假图片”找出来。就这样,generator网络不断改进使其可以骗过discriminator网络,而discriminator网络不断改进使其可以更准确找到“假图片”,这种相互促进相互对抗的关系,就叫做对抗网络。图一中展示了GAN模型的结构。

思路梳理

将MNIST数据集中标签为0的图片提取出来,然后训练discriminator网络,进行手写数字0识别,接着让generator产生一张随机图片,让训练好的discriminator去识别这张生成的图片,不断训练discriminator,直到discriminator网络将生成的图片当做数字0为止。

生成“假图片

生成一张随机像素的28*28的图片,分别进行全连接,Leaky ReLU函数激活,dropout处理(随机丢弃一些神经元,防止过拟合),全连接,tanh函数激活,最终生成一张“假图片”,TensorFlow代码如下:

1def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
2    with tf.variable_scope("generator", reuse=reuse):
3        hidden1 = tf.layers.dense(noise_img, n_units)  # 全连接层
4        hidden1 = tf.maximum(alpha * hidden1, hidden1)
5        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
6        logits = tf.layers.dense(hidden1, out_dim)
7        outputs = tf.tanh(logits)
8        return logits, outputs

图像判别

将需要进行判别的图片先后经过全连接,Leaky ReLU函数激活,全连接,sigmoid函数激活处理,最终输出图片的识别结果,TensorFlow代码如下:

1def get_discriminator(img, n_units, reuse=False, alpha=0.01):
2    with tf.variable_scope("discriminator", reuse=reuse):
3        hidden1 = tf.layers.dense(img, n_units)
4        hidden1 = tf.maximum(alpha * hidden1, hidden1)
5        logits = tf.layers.dense(hidden1, 1)
6        outputs = tf.sigmoid(logits)
7        return logits, outputs

完整代码

GAN手写数字识别的完整代码如下:

  1import tensorflow as tf
 2from tensorflow.examples.tutorials.mnist import input_data
 3import matplotlib.pyplot as plt
 4import numpy as np
 5
 6mnist = input_data.read_data_sets("E:/Tensor/MNIST_data/")
 7img = mnist.train.images[50]
 8
 9
10def get_inputs(real_size, noise_size):
11    real_img = tf.placeholder(tf.float32, [None, real_size], name="real_img")
12    noise_img = tf.placeholder(tf.float32, [None, noise_size], name="noise_img")
13    return real_img, noise_img
14
15
16# 生成图像
17def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
18    with tf.variable_scope("generator", reuse=reuse):
19        hidden1 = tf.layers.dense(noise_img, n_units)  # 全连接层
20        hidden1 = tf.maximum(alpha * hidden1, hidden1)
21        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
22        logits = tf.layers.dense(hidden1, out_dim)
23        outputs = tf.tanh(logits)
24        return logits, outputs
25
26
27# 图像判别
28def get_discriminator(img, n_units, reuse=False, alpha=0.01):
29    with tf.variable_scope("discriminator", reuse=reuse):
30        hidden1 = tf.layers.dense(img, n_units)
31        hidden1 = tf.maximum(alpha * hidden1, hidden1)
32        logits = tf.layers.dense(hidden1, 1)
33        outputs = tf.sigmoid(logits)
34        return logits, outputs
35#真实图像size
36img_size = mnist.train.images[0].shape[0]
37#传入generator的噪声size
38noise_size = 100
39#生成器隐层参数
40g_units = 128
41#判别器隐层参数
42d_units = 128
43#Leaky ReLU参数
44alpha = 0.01
45#学习率
46learning_rate = 0.001
47#label smoothing
48smooth = 0.1
49tf.reset_default_graph()
50real_img, noise_img = get_inputs(img_size, noise_size)
51g_logits, g_outputs = get_generator(noise_img, g_units, img_size)
52
53d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
54d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)
55
56d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
57    logits=d_logits_real, labels=tf.ones_like(d_logits_real)
58) * (1 - smooth))
59d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
60    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)
61))
62d_loss = tf.add(d_loss_real, d_loss_fake)
63g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
64    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)
65) * (1 - smooth))
66
67train_vars = tf.trainable_variables()
68g_vars = [var for var in train_vars if var.name.startswith("generator")]
69d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
70
71d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
72g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
73
74
75epochs = 10000
76samples = []
77n_sample = 10
78losses = []
79
80i = j = 0
81while i<10000:
82    if mnist.train.labels[j] == 0:
83        samples.append(mnist.train.images[j])
84        i += 1
85    j += 1
86
87print(len(samples))
88size = samples[0].size
89
90with tf.Session() as sess:
91    tf.global_variables_initializer().run()
92    for e in range(epochs):
93        batch_images = samples[e] * 2 -1
94        batch_noise = np.random.uniform(-1, 1, size=noise_size)
95
96        _ = sess.run(d_train_opt, feed_dict={real_img:[batch_images], noise_img:[batch_noise]})
97        _ = sess.run(g_train_opt, feed_dict={noise_img:[batch_noise]})
98
99    sample_noise = np.random.uniform(-1, 1, size=noise_size)
100    g_logit, g_output = sess.run(get_generator(noise_img, g_units, img_size,
101                                         reuse=True), feed_dict={
102        noise_img:[sample_noise]
103    })
104    print(g_logit.size)
105    g_output = (g_output+1)/2
106    plt.imshow(g_output.reshape([28, 28]), cmap='Greys_r')
107    plt.show()

训练效果

在经过了10000次的迭代后,generator网络生成的图片已经接近手写数字零的形状。

本文是对GAN模型的初次探索,在后续GAN模型的系列文章中,笔者将层层深入的去讲解GAN模型复杂的应用。

GAN模型生成手写字的更多相关文章

  1. 使用生成对抗网络(GAN)生成手写字

    先放结果 这是通过GAN迭代训练30W次,耗时3小时生成的手写字图片效果,大部分的还是能看出来是数字的. 实现原理 简单说下原理,生成对抗网络需要训练两个任务,一个叫生成器,一个叫判别器,如字面意思, ...

  2. GAN实战笔记——第三章第一个GAN模型:生成手写数字

    第一个GAN模型-生成手写数字 一.GAN的基础:对抗训练 形式上,生成器和判别器由可微函数表示如神经网络,他们都有自己的代价函数.这两个网络是利用判别器的损失记性反向传播训练.判别器努力使真实样本输 ...

  3. 用TensorFlow教你手写字识别

    博主原文链接:用TensorFlow教你做手写字识别(准确率94.09%) 如需转载,请备注出处及链接,谢谢. 2012 年,Alex Krizhevsky, Geoff Hinton, and Il ...

  4. [Deep-Learning-with-Python]GAN图片生成

    GAN 由Goodfellow等人于2014年引入的生成对抗网络(GAN)是用于学习图像潜在空间的VAE的替代方案.它们通过强制生成的图像在统计上几乎与真实图像几乎无法区分,从而能够生成相当逼真的合成 ...

  5. tensorflow卷积神经网络与手写字识别

    1.知识点 """ 基础知识: 1.神经网络(neural networks)的基本组成包括输入层.隐藏层.输出层.而卷积神经网络的特点在于隐藏层分为卷积层和池化层(po ...

  6. 【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像

    前言 深度学习作为人工智能的重要手段,迎来了爆发,在NLP.CV.物联网.无人机等多个领域都发挥了非常重要的作用.最近几年,各种深度学习算法层出不穷, Generative Adverarial Ne ...

  7. VS2010 根据模型生成数据库 打开edmx.sql文件时 vs出现无响应的解决方案

    今天在VS2010 sp1+sql server 2008 R2+Win7操作系统下测试ADO.NET 实体数据模型时 ,遇到这样一个问题. 首先建好实体模型,然后"根据模型生成数据库&qu ...

  8. 根据powerdesigner的OO模型生成C#代码

    2007-05-15 08:34:11|  分类: 转贴部分 |  标签:学习帖子 |字号 订阅 习惯了用Powerdesigner设计数据库模型,XDE设计类图.因此我一般的设计方法是用PD做分析模 ...

  9. MySQL Workbench将模型生成SQL文件出错

    采用MySQL Workbench 设计好表和表关系后,从 File | Export 菜单中,选择 Forward Engineer SQL CREATE Script(正向引擎), 将我们的模型生 ...

随机推荐

  1. mybatis逆向工程的注意事项,以及数据库表

    1.选择性更新,如果有新参数就更换成新参数,如果参数是null就不更新,还是原来的参数 2.mybatis使用逆向工程,数据库建表的字段user_id必须用下滑线隔开,这样生成的对象private L ...

  2. 关于Django字段类型中 blank和null的区别

    blank 设置为True时,字段可以为空.设置为False时,字段是必须填写的.字符型字段CharField和TextField是用空字符串来存储空值的. 如果为True,字段允许为空,默认不允许. ...

  3. js-day05-JSON-jQuery初体验

    JSON数据格式 JSON(JavaScript Object Notation)一种简单的数据格式,比xml更轻巧.易于人阅读和编写,同时也易于机器解析和生成(网络传输速度快)JSON是JavaSc ...

  4. Openstack中RabbitMQ RPC代码分析

    在Openstack中,RPC调用是通过RabbitMQ进行的. 任何一个RPC调用,都有Client/Server两部分,分别在rpcapi.py和manager.py中实现. 这里以nova-sc ...

  5. TCP协议学习总结(下)

    在前两边TCP学习总结中,也大概地学习了TCP的整个流程,但许多细节中的细节并没有详细学习,例如超时重传问题,每次瓶颈回归慢启动效率问题以及最大窗口限制问题等.本学习篇章最要针对这些细节中的细节进行学 ...

  6. 贪心算法----区间覆盖问题(POJ2376)

    题目: 题目的大概意思是约翰这个农民有N条牛,这些牛可以在一天中的某个时间段可以进行工作,他想把这个时间段分成若干个片段让这些牛去进行打扫任务,你的任务是安排尽量少的牛然后可以完成分成这些片段的打扫任 ...

  7. S-CMS企业建站v3几处SQL注入

    0x01 前言 有段时间没有发文章了,主要没挖到比较有意思的漏洞点.然后看最近爆了很多关于S-CMS的漏洞,下载了源码简单挖了一下然后给大家分享一下. 0x02 目录 Wap_index.php sq ...

  8. [Swift]LeetCode142. 环形链表 II | Linked List Cycle II

    Given a linked list, return the node where the cycle begins. If there is no cycle, return null. Note ...

  9. [Swift]LeetCode441. 排列硬币 | Arranging Coins

    You have a total of n coins that you want to form in a staircase shape, where every k-th row must ha ...

  10. [Swift]LeetCode880. 索引处的解码字符串 | Decoded String at Index

    An encoded string S is given.  To find and write the decodedstring to a tape, the encoded string is ...