2019-3-10——生成对抗网络GAN---生成mnist手写数字图像
"""
生成对抗网络(GAN,Generative Adversarial Networks)的基本原理很简单:
假设有两个网络,生成网络G和判别网络D。生成网络G接受一个随机的噪声z并生成图片,
记为G(z);判别网络D的作用是判别一张图片x是否真实,对于输入x,D(x)是x为真实图片的概率。
在训练过程中, 生成器努力让生成的图片更加真实从而使得判别器无法辨别图像的真假,
而D的目标就是尽量把分辨出真实图片和生成网络G产出的图片,这个过程就类似于二人博弈,
G和D构成了一个动态的“博弈过程”。随着时间的推移,生成器和判别器在不断地进行对抗,
最终两个网络达到一个动态平衡:生成器生成的图像G(z)接近于真实图像分布,而判别器识别不出真假图像,
即D(G(z))=0.5。最后,我们就可以得到一个生成网络G,用来生成图片。
"""
import tensorflow as tf
from matplotlib import pyplot as plt
import os
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('/MNIST_data/',one_hot=True)
batch_size=64
units_size=128
learning_rate=0.001
epoch=300
smooth=0.1
"""定义生成模型"""
def generatorModel(noise_img,units_size,out_size,alpha=0.01):
"""生成器的目的是:对于生成的图片,G希望D打上标签1"""
with tf.variable_scope('generator'):
FC=tf.layers.dense(noise_img,units_size)
relu=tf.nn.leaky_relu(FC,alpha)
drop=tf.layers.dropout(relu,rate=0.2)
logits=tf.layers.dense(drop,out_size)
outputs=tf.tanh(logits)
return logits,outputs """定义判别模型"""
def discriminatorModel(images,unite_size,alpha=0.01,reuse=False):
"""
判别器的目的是:
1. 对于真实图片,D要为其打上标签1
2. 对于生成图片,D要为其打上标签0
"""
with tf.variable_scope('discriminator',reuse=reuse):
FC=tf.layers.dense(images,units_size)
relu=tf.nn.leaky_relu(FC,alpha)
logits=tf.layers.dense(relu,1)
outputs=tf.sigmoid(logits)
return logits,outputs
"""定义损失函数"""
def loss_fenction(real_logits,fake_logits,smooth):
"""生成器希望判别器判别出来的标签为1; tf.ones_like()创建一个将所有元素都设置为1的张量"""
G_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=fake_logits,
labels=tf.ones_like(fake_logits)*(1-smooth))
)
"""判别器识别生成器产出的图片,希望识别出来的标签为0;tf.zeros_like()创建一个将所有元素都设置为0的张量"""
fake_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=fake_logits,
labels=tf.zeros_like(fake_logits))
)
"""判别器判别真实图片,希望判别出来的标签为1;tf.ones_like()创建一个将所有元素都设置为1的张量"""
real_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=real_logits,
labels=tf.ones_like(real_logits)*(1-smooth))
)
D_loss=tf.add(fake_loss,real_loss)
return G_loss,fake_loss,real_loss,D_loss
"""定义优化器"""
def optimizer(G_loss,D_loss,learning_rate):
"""因为GAN中一共训练了两个网络,所以分别对G和D进行优化"""
train_var=tf.trainable_variables() #需要训练的变量
G_var=[var for var in train_var if var.name.startswith('generator')]
D_var=[var for var in train_var if var.name.startswith('discriminator')]
G_optimizer=tf.train.AdadeltaOptimizer(learning_rate).minimize(G_loss,var_list=G_var)
D_optimizer=tf.train.AdadeltaOptimizer(learning_rate).minimize(D_loss,var_list=D_var)
return G_optimizer,D_optimizer
"""训练"""
def train(mnist):
image_size = mnist.train.images[0].shape[0]
real_images = tf.placeholder(tf.float32,[None,image_size])
fake_images = tf.placeholder(tf.float32,[None,image_size])
"""调用生成模型生成假图片G_output"""
G_logits,G_output = generatorModel(fake_images,units_size,image_size)
"""D对真实图像的判别"""
real_logits,real_output = discriminatorModel(real_images,units_size)
"""D对G生成图像的判别"""
fake_logits,fake_output=discriminatorModel(G_output,units_size,reuse=True)
G_loss,real_loss,fake_loss,D_loss=loss_fenction(real_logits,fake_logits,smooth)
G_optimizer,D_optimizer=optimizer(G_loss,D_loss,learning_rate) saver=tf.train.Saver()
step=0
with tf.Session() as session:
session.run(tf.global_variables_initializer())
for Epoch in range(epoch):
for batch_i in range(mnist.train.num_examples//batch_size):
batch_image,_=mnist.train.next_batch(batch_size)
"""对图像像素进行scale,tanh的输出结果为(-1,1)"""
batch_image=batch_image*2-1
"""模型的输入噪声"""
noise_image=np.random.uniform(-1,1,size=(batch_size,image_size))#从均匀分布[-1,1)中随机采样
session.run(G_optimizer,feed_dict={fake_images:noise_image})
session.run(D_optimizer,feed_dict={real_images:batch_image,fake_images:noise_image})
step=step+1
loss_D= session.run(D_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
loss_real= session.run(real_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
loss_fake= session.run(fake_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
loss_G= session.run(G_loss, feed_dict={fake_images: noise_image})
print('epoch:', Epoch, 'loss_D:', loss_D,'loss_real', loss_real,'loss_fake', loss_fake, 'loss_G', loss_G)
model_path=os.getcwd()+os.sep+"mnist.model"
saver.save(session,model_path,global_step=step)
"""定义主函数"""
def main(argv=None):
train(mnist)
if __name__ =='__main__':
tf.app.run()
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import pickle
import example88_0 UNITS_SIZE = example88_0.units_size def generatorImage(image_size):
sample_images = tf.placeholder(tf.float32, [None, image_size])
G_logits, G_output = example88_0.generatorModel(sample_images, UNITS_SIZE, image_size)
saver = tf.train.Saver()
with tf.Session() as session:
session.run(tf.global_variables_initializer())
saver.restore(session, tf.train.latest_checkpoint('.'))
sample_noise = np.random.uniform(-1, 1, size=(25, image_size))
samples = session.run(G_output, feed_dict={sample_images: sample_noise})
with open('samples.pkl', 'wb') as f:
pickle.dump(samples, f) def show():
with open('samples.pkl', 'rb') as f:
samples = pickle.load(f)
fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
for ax, image in zip(axes.flatten(), samples):
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
ax.imshow(image.reshape((28, 28)), cmap='Greys_r')
plt.show() def main(argv=None):
image_size = example88_0.mnist.train.images[0].shape[0]
generatorImage(image_size)
show() if __name__ == '__main__':
tf.app.run()

2019-3-10——生成对抗网络GAN---生成mnist手写数字图像的更多相关文章
- 用MXNet实现mnist的生成对抗网络(GAN)
用MXNet实现mnist的生成对抗网络(GAN) 生成式对抗网络(Generative Adversarial Network,简称GAN)由一个生成网络与一个判别网络组成.生成网络从潜在空间(la ...
- 人工智能中小样本问题相关的系列模型演变及学习笔记(二):生成对抗网络 GAN
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] [再啰嗦一下]本文衔接上一个随笔:人工智能中小样本问题相关的系列模型演变及学习 ...
- 生成对抗网络GAN介绍
GAN原理 生成对抗网络GAN由生成器和判别器两部分组成: 判别器是常规的神经网络分类器,一半时间判别器接收来自训练数据中的真实图像,另一半时间收到来自生成器中的虚假图像.训练判别器使得对于真实图像, ...
- TensorFlow从1到2(十二)生成对抗网络GAN和图片自动生成
生成对抗网络的概念 上一篇中介绍的VAE自动编码器具备了一定程度的创造特征,能够"无中生有"的由一组随机数向量生成手写字符的图片. 这个"创造能力"我们在模型中 ...
- 深度学习-生成对抗网络GAN笔记
生成对抗网络(GAN)由2个重要的部分构成: 生成器G(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器 判别器D(Discriminator):判断这张图像是真实的 ...
- 深度学习框架PyTorch一书的学习-第七章-生成对抗网络(GAN)
参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter7-GAN生成动漫头像 GAN解决了非监督学习中的著名问题:给定一批样本,训 ...
- 生成对抗网络(GAN)
基本思想 GAN全称生成对抗网络,是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的. 譬如:我要升职加薪,你领导力还不行,我现在领导力有了要升职加薪,你执行力还不行,我现在执行力有了要升职加薪 ...
- 科普 | 生成对抗网络(GAN)的发展史
来源:https://en.wikipedia.org/wiki/Edmond_de_Belamy 五年前,Generative Adversarial Networks(GANs)在深度学习领域掀起 ...
- 利用tensorflow训练简单的生成对抗网络GAN
对抗网络是14年Goodfellow Ian在论文Generative Adversarial Nets中提出来的. 原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(di ...
随机推荐
- Kali 局域网 DNS 劫持
<一> 所需工具 1: Kali-linux-2017 2: ettercap 0.8.2 3: web 服务器, 这里以 node 为例 <二> 原理 1: DNS劫持 ...
- leetcode993
public class Node { public int CurNode; public int FatherNode; public int Layer; } public class Solu ...
- Django中MEDIA_ROOT和MEDIA_URL
在django上传图片前端使用动态的配置方法 MEDIA_ROOT 代表着 要上传的路径会和你在models中写的上传的路径进行拼节形成最终文件上传的路径 MEDIA_URL主要就是映射了 在前端使用 ...
- oracle删除表数据的两种的方式
转自:https://blog.csdn.net/qq_37840993/article/details/82490787 平时写sql中我们都会用到删除语句,而平时删除表数据的时候我们经常会用到两种 ...
- 机器学习入门-文本数据-构造Tf-idf词袋模型(词频和逆文档频率) 1.TfidfVectorizer(构造tf-idf词袋模型)
TF-idf模型:TF表示的是词频:即这个词在一篇文档中出现的频率 idf表示的是逆文档频率, 即log(文档的个数/1+出现该词的文档个数) 可以看出出现该词的文档个数越小,表示这个词越稀有,在这 ...
- 尚硅谷redis学习3-redis启动以后的杂项
redis速度很快,运行benchmark可以看出,各项运行速度可达100000次每秒 redis默认有16个数据库,分别是0, 1 ... 15,默认在0号库,可以通过select num转到其它库 ...
- img标签在div里上下居中
方法一:图片尺寸未知,IE8-不支持 CSS部分: <style> .content{ width:500px; height:500px; border:1px solid black; ...
- gitlab jenkins 自动构建
工作中有这样一种需求: 每次提交代码之后,都自动执行 单元测试脚本,进行单元测试 jenkins监听项目的某个分支,设置运行脚本,设置一个url作为回调 利用gitlab的钩子,在每次有提交之后,触发 ...
- 100个常用的Linux命令——转载
1,echo “aa” > test.txt 和 echo “bb” >> test.txt //>将原文件清空,并且内容写入到文件中,>>将内容放到文件的尾部 2 ...
- C#中Graphics的画图代码【转】
我要写多几个字上去 string str = "Baidu"; //写什么字? Font font = Font("宋体",30f); //字是什么样子的? B ...