参考资料

GAN原理学习笔记

生成式对抗网络GAN汇总

GAN的理解与TensorFlow的实现

TensorFlow小试牛刀(2):GAN生成手写数字

参考代码之一

#coding=utf-8
#http://blog.csdn.net/u012223913/article/details/75051516?locationNum=1&fps=1
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from tensorflow.examples.tutorials.mnist import input_data sess = tf.InteractiveSession() mb_size = 128
Z_dim = 100 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) def weight_var(shape, name):
return tf.get_variable(name=name, shape=shape, initializer=tf.contrib.layers.xavier_initializer()) def bias_var(shape, name):
return tf.get_variable(name=name, shape=shape, initializer=tf.constant_initializer(0)) # discriminater net X = tf.placeholder(tf.float32, shape=[None, 784], name='X') # X [128 784] z1 = W * x + b
# 矩阵乘法 128 * 784 * 784 * 128 = [128 128]
D_W1 = weight_var([784, 128], 'D_W1')
D_b1 = bias_var([128], 'D_b1') #z2 = W * z1 + b
# 矩阵乘法 128 * 128 * 128 * 1 = [128 1] 输出判决结果,二分类
D_W2 = weight_var([128, 1], 'D_W2')
D_b2 = bias_var([1], 'D_b2'
)
theta_D = [D_W1, D_W2, D_b1, D_b2] # generator net Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z') # z [128 784] z1 = W * x + b
# 矩阵乘法 128 * 100 * 100 * 128 = [128 128]
G_W1 = weight_var([100, 128], 'G_W1')
G_b1 = bias_var([128], 'G_B1') #z2 = W * z1 + b
# 矩阵乘法 128 * 128 * 128 * 784 = [128 784] 输出28*28的图像
G_W2 = weight_var([128, 784], 'G_W2')
G_b2 = bias_var([784], 'G_B2') theta_G = [G_W1, G_W2, G_b1, G_b2] def generator(z):
G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob) return G_prob def discriminator(x):
D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
D_logit = tf.matmul(D_h1, D_W2) + D_b2
D_prob = tf.nn.sigmoid(D_logit)
return D_prob, D_logit G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample) #discriminator输出为1表示ground truth
#discriminator输出为0表示非ground truth
#对于生成网络希望两点:
#(2)希望D_real尽可能大,这样保证正确识别真正的样本
#(1)希望D_fake尽可能小,这样可以剔除假的生成样本
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake)) #对于判别网络, 希望D_fake尽可能大,这样可以迷惑生成网络,
G_loss = -tf.reduce_mean(tf.log(D_fake)) D_optimizer = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_optimizer = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) init = tf.initialize_all_variables()
saver = tf.train.Saver()
# 启动默认图
sess = tf.Session()
# 初始化
sess.run(init) def sample_Z(m, n):
'''Uniform prior for G(Z)'''
return np.random.uniform(-1., 1., size=[m, n]) def plot(samples):
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05) for i, sample in enumerate(samples): # [i,samples[i]] imax=16
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r') return fig if not os.path.exists('out/'):
os.makedirs('out/') i = 0 for it in range(1000000):
if it % 1000 == 0:
samples = sess.run(G_sample, feed_dict={
Z: sample_Z(16, Z_dim)}) # 16*784
fig = plot(samples)
plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
i += 1
plt.close(fig) X_mb, _ = mnist.train.next_batch(mb_size)#ground truth _, D_loss_curr = sess.run([D_optimizer, D_loss], feed_dict={
X: X_mb, Z: sample_Z(mb_size, Z_dim)})
_, G_loss_curr = sess.run([G_optimizer, G_loss], feed_dict={
Z: sample_Z(mb_size, Z_dim)}) if it % 1000 == 0:
print('Iter: {}'.format(it))
print('D loss: {:.4}'.format(D_loss_curr))
print('G_loss: {:.4}'.format(G_loss_curr))
print()

参考代码之二

#http://blog.csdn.net/sparkexpert/article/details/70147409

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
#from skimage.io import imsave
import scipy
import os
import shutil img_height = 28
img_width = 28
img_size = img_height * img_width to_train = True
to_restore = False
output_path = "output" # 总迭代次数500
max_epoch = 500 h1_size = 150
h2_size = 300
z_size = 100
batch_size = 256 # generate (model 1)
def build_generator(z_prior):
w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32)
b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)
h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)
w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)
b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)
h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)
w3 = tf.Variable(tf.truncated_normal([h2_size, img_size], stddev=0.1), name="g_w3", dtype=tf.float32)
b3 = tf.Variable(tf.zeros([img_size]), name="g_b3", dtype=tf.float32)
h3 = tf.matmul(h2, w3) + b3
x_generate = tf.nn.tanh(h3)
g_params = [w1, b1, w2, b2, w3, b3]
return x_generate, g_params # discriminator (model 2)
def build_discriminator(x_data, x_generated, keep_prob):
# tf.concat
x_in = tf.concat([x_data, x_generated], 0)
w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)
b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)
w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)
b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)
w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)
b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
h3 = tf.matmul(h2, w3) + b3
y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))
y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))
d_params = [w1, b1, w2, b2, w3, b3]
return y_data, y_generated, d_params #
def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):
batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5
img_h, img_w = batch_res.shape[1], batch_res.shape[2]
grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
for i, res in enumerate(batch_res):
if i >= grid_size[0] * grid_size[1]:
break
img = (res) * 255
img = img.astype(np.uint8)
row = (i // grid_size[0]) * (img_h + grid_pad)
col = (i % grid_size[1]) * (img_w + grid_pad)
img_grid[row:row + img_h, col:col + img_w] = img
#imsave(fname, img_grid)
#img.save('output/num.jpg')
scipy.misc.imsave(fname, img_grid) def train():
# load data(mnist手写数据集)
mnist = input_data.read_data_sets('mnist_data', one_hot=True) x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data")
z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
keep_prob = tf.placeholder(tf.float32, name="keep_prob")
global_step = tf.Variable(0, name="global_step", trainable=False) # 创建生成模型
x_generated, g_params = build_generator(z_prior)
# 创建判别模型
y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob) # 损失函数的设置
d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))
g_loss = - tf.log(y_generated) optimizer = tf.train.AdamOptimizer(0.0001) # 两个模型的优化函数
d_trainer = optimizer.minimize(d_loss, var_list=d_params)
g_trainer = optimizer.minimize(g_loss, var_list=g_params) init = tf.initialize_all_variables() saver = tf.train.Saver()
# 启动默认图
sess = tf.Session()
# 初始化
sess.run(init) if to_restore:
chkpt_fname = tf.train.latest_checkpoint(output_path)
saver.restore(sess, chkpt_fname)
else:
if os.path.exists(output_path):
shutil.rmtree(output_path)
os.mkdir(output_path) z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32) steps = 60000 / batch_size
for i in range(sess.run(global_step), max_epoch):
for j in np.arange(steps):
# for j in range(steps):
print("epoch:%s, iter:%s" % (i, j))
# 每一步迭代,我们都会加载256个训练样本,然后执行一次train_step
x_value, _ = mnist.train.next_batch(batch_size)
x_value = 2 * x_value.astype(np.float32) - 1
z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
# 执行生成
sess.run(d_trainer,
feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
# 执行判别
if j % 1 == 0:
sess.run(g_trainer,
feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
show_result(x_gen_val, "output/sample{0}.jpg".format(i))
z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
show_result(x_gen_val, "output/random_sample{0}.jpg".format(i))
sess.run(tf.assign(global_step, i + 1))
saver.save(sess, os.path.join(output_path, "model"), global_step=global_step) def test():
z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
x_generated, _ = build_generator(z_prior)
chkpt_fname = tf.train.latest_checkpoint(output_path) init = tf.initialize_all_variables()
sess = tf.Session()
saver = tf.train.Saver()
sess.run(init)
saver.restore(sess, chkpt_fname)
z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
show_result(x_gen_val, "output/test_result.jpg") if __name__ == '__main__':
if to_train:
train()
else:
test()

GAN 生成mnist数据的更多相关文章

  1. GAN生成式对抗网络(三)——mnist数据生成

    通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...

  2. 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)

    变分自编码器(VAE,variatinal autoencoder)   VS    生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...

  3. 深度学习之 GAN 进行 mnist 图片的生成

    深度学习之 GAN 进行 mnist 图片的生成 mport numpy as np import os import codecs import torch from PIL import Imag ...

  4. 对抗生成网络-图像卷积-mnist数据生成(代码) 1.tf.layers.conv2d(卷积操作) 2.tf.layers.conv2d_transpose(反卷积操作) 3.tf.layers.batch_normalize(归一化操作) 4.tf.maximum(用于lrelu) 5.tf.train_variable(训练中所有参数) 6.np.random.uniform(生成正态数据

    1. tf.layers.conv2d(input, filter, kernel_size, stride, padding) # 进行卷积操作 参数说明:input输入数据, filter特征图的 ...

  5. GAN︱生成模型学习笔记(运行机制、NLP结合难点、应用案例、相关Paper)

    我对GAN"生成对抗网络"(Generative Adversarial Networks)的看法: 前几天在公开课听了新加坡国立大学[机器学习与视觉实验室]负责人冯佳时博士在[硬 ...

  6. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  7. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  8. GAN生成的评价指标 Evaluation of GAN

    传统方法中,如何衡量一个generator ?-- 用 generator 产生数据的 likelihood,越大越好. 但是 GAN 中的 generator 是隐式建模,所以只能从 P_G 中采样 ...

  9. Enterprise Solution 生成实体数据访问接口与实现类型 Code Smith 6.5 模板文件下载

    数据库表定义为SalesOrder,用LLBL Gen Pro生成的实体定义是SalesOrderEntity,再用Code Smith生成的数据读写接口是ISalesOrderManager,最后是 ...

随机推荐

  1. GuildBrowser使用AF+MVC 学习笔记

    GuildBrowser 是一个 测试用的项目此为 魔兽世界api的一个展示客户端 项目地址:https://github.com/yehai/GuildBrowser 一:所使用的设计模式:MVC ...

  2. 判断IE浏览器版本的精简脚本

    IE浏览器不管是什么版本,总是跟Web标准有些不太兼容.对于代码工作者来说,自然是苦不堪言,为了考虑IE的兼容问题,不管是写 CSS 还是 JS,往往都要对 IE 特别对待,这就少不了做些判断.本文不 ...

  3. android开机启动代码

    1)public class StartupReceiver extends BroadcastReceiver { @Override public void onReceive(Context c ...

  4. Mybatis解决字段名与实体类属性名不相同的冲突

    在平时的开发中,我们表中的字段名和表对应实体类的属性名称不一定都是完全相同的,下面来演示一下这种情况下的如何解决字段名与实体类属性名不相同的冲突. 一.准备演示需要使用的表和数据 CREATE TAB ...

  5. linux cut 命令

    cut:以某种方式按照文件的行进行分割 参数列表: -b 按字节选取 忽略多字节字符边界,除非也指定了 -n 标志 -c 按字符选取 -d 自定义分隔符,默认为制表符. -f 与-d一起使用,指定显示 ...

  6. springmvc验证数据

    1.引入jar包 com.springsource.javax.validation-1.0.0.GA.jar  规范(只是定义) hibernate-validator-4.1.0.Final.ja ...

  7. iOS8的一些控件的变更

    UISearchDisplayController变更为UISearchController UIAlertView变更为UIAlertController 如果添加点击事件则需要使用UIAlertC ...

  8. 随笔小问题(一)--mac打开class文件

    本来不想写这个东西的.但是这个却费了我一番周折. 我要先声明一点的是,我从来不讲iOS当成一个单独的系统,而是将这个操作系统归位unix内核的系统. 简单来说,我把它当成linux在用. 但是,mac ...

  9. java反射详解及说明

    首先写一个Person类: package lltse.base.reflectdemo; public class Person { private String name ="张三&qu ...

  10. hibernate 中多对多关系对象集合的保存

    多对多关系映射和一对多关系映射开发步骤差不多, 例子如下:员工和项目之间的关系,一个员工可以参与多个项目:一个项目可以有多个开发人员参与.因此是多对多的关系. 1 分析数据表 1.1)员工表 CREA ...