通过GAN生成式对抗网络,产生mnist数据

引入包,数据约定等

import numpy as np
import matplotlib.pyplot as plt
import input_data #读取数据的一个工具文件,不影响理解
import tensorflow as tf # 获取数据
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images X = mnist.train.images[:, :]
batch_size = 64 #用来返回真实数据
def iterate_minibatch(x, batch_size, shuffle=True):
indices = np.arange(x.shape[0])
if shuffle:
np.random.shuffle(indices)
for i in range(0, x.shape[0]-1000, batch_size):
temp = x[indices[i:i + batch_size], :]
temp = np.array(temp) * 2 - 1
yield np.reshape(temp, (-1, 28, 28, 1))

GAN对象结构

class GAN(object):
def __init__(self):
#初始函数,在这里对初始化模型
def netG(self, z):
#生成器模型
def netD(self, x, reuse=False):
#判别器模型

生成器函数

对随机值z(维度为1,100),进行包装,伪造,产生伪造数据。

包装过程概括为:全连接->reshape->反卷积

包装过程中使用了batch_normalization,Leaky ReLU,dropout,tanh等技巧

   #对随机值z(维度为1,100),进行包装,伪造,产生伪造数据。
#包装过程概括为:全连接->reshape->反卷积
#包装过程中使用了batch_normalization,Leaky ReLU,dropout,tanh等技巧
def netG(self,z,alpha=0.01):
with tf.variable_scope('generator') as scope:
layer1 = tf.layers.dense(z, 4 * 4 * 512) # 这是一个全连接层,输出 (n,4*4*512)
layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
# batch normalization
layer1 = tf.layers.batch_normalization(layer1, training=True) # 做BN标准化处理
# Leaky ReLU
layer1 = tf.maximum(alpha * layer1, layer1)
# dropout
layer1 = tf.nn.dropout(layer1, keep_prob=0.8) # 4 x 4 x 512 to 7 x 7 x 256
layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=1, padding='valid')
layer2 = tf.layers.batch_normalization(layer2, training=True)
layer2 = tf.maximum(alpha * layer2, layer2)
layer2 = tf.nn.dropout(layer2, keep_prob=0.8) # 7 x 7 256 to 14 x 14 x 128
layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
layer3 = tf.layers.batch_normalization(layer3, training=True)
layer3 = tf.maximum(alpha * layer3, layer3)
layer3 = tf.nn.dropout(layer3, keep_prob=0.8) # 14 x 14 x 128 to 28 x 28 x 1
logits = tf.layers.conv2d_transpose(layer3, 1, 3, strides=2, padding='same')
# MNIST原始数据集的像素范围在0-1,这里的生成图片范围为(-1,1)
# 因此在训练时,记住要把MNIST像素范围进行resize
outputs = tf.tanh(logits) return outputs

判别器函数

通过深度卷积+全连接的形式,判别器将输入分类为真数据,还是假数据。

    def netD(self, x, reuse=False,alpha=0.01):
with tf.variable_scope('discriminator') as scope:
if reuse:
scope.reuse_variables()
layer1 = tf.layers.conv2d(x, 128, 3, strides=2, padding='same')
layer1 = tf.maximum(alpha * layer1, layer1)
layer1 = tf.nn.dropout(layer1, keep_prob=0.8) # 14 x 14 x 128 to 7 x 7 x 256
layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')
layer2 = tf.layers.batch_normalization(layer2, training=True)
layer2 = tf.maximum(alpha * layer2, layer2)
layer2 = tf.nn.dropout(layer2, keep_prob=0.8) # 7 x 7 x 256 to 4 x 4 x 512
layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')
layer3 = tf.layers.batch_normalization(layer3, training=True)
layer3 = tf.maximum(alpha * layer3, layer3)
layer3 = tf.nn.dropout(layer3, keep_prob=0.8) # 4 x 4 x 512 to 4*4*512 x 1
flatten = tf.reshape(layer3, (-1, 4 * 4 * 512))
f = tf.layers.dense(flatten, 1)
return f

初始化函数

有一个前置训练,将真实数据喂给判别器,训练判别器的鉴别能力

    # 有一个前置训练,将真实数据喂给判别器,训练判别器的鉴别能力
def __init__(self):
self.z = tf.placeholder(tf.float32, shape=[batch_size, 100], name='z') # 随机输入值
self.x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1], name='real_x') # 图片值 self.fake_x = self.netG(self.z) # 将随机输入,包装为伪造图片值 self.pre_logits = self.netD(self.x, reuse=False) # 判别器预训练时,判别器对真实数据的判别情况-未sigmoid处理
self.real_logits = self.netD(self.x, reuse=True) # 判别器对真实数据的判别情况-未sigmoid处理
self.fake_logits = self.netD(self.fake_x, reuse=True) # 判别器对伪造数据的判别情况-未sigmoid处理 # 预训练时判别器,判别器将真实数据判定为真的得分情况。
self.loss_pre_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.pre_logits,
labels=tf.ones_like(self.pre_logits)))
# 训练时,判别器将真实数据判定为真,将伪造数据判定为假的得分情况。
self.loss_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_logits,
labels=tf.ones_like(self.real_logits))) + \
tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
labels=tf.zeros_like(self.fake_logits)))
# 训练时,生成器伪造的数据,被判定为真实数据的得分情况。
self.loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
labels=tf.ones_like(self.fake_logits))) # 获取生成器和判定器对应的变量地址,用于更新变量
t_vars = tf.trainable_variables()
self.g_vars = [var for var in t_vars if var.name.startswith("generator")]
self.d_vars = [var for var in t_vars if var.name.startswith("discriminator")]

开始训练

gan = DCGAN()
#预训练时的梯度优化函数
d_pre_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_pre_D, var_list=gan.d_vars)
#判别器的梯度优化函数
d_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_D, var_list=gan.d_vars)
#预训练时的梯度优化函数
g_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_G, var_list=gan.g_vars) init = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
#对判别器的预训练,训练了两个epoch
for i in range(2):
print('判别器初始训练,第' + str(i) + '次包')
for x_batch in iterate_minibatch(X, batch_size=batch_size):
loss_pre_D, _ = sess.run([gan.pre_logits, d_pre_optim],
feed_dict={
gan.x: x_batch
})
#训练5个epoch
for epoch in range(5):
print('对抗' + str(epoch) + '次包')
avg_loss = 0
count = 0
for x_batch in iterate_minibatch(X, batch_size=batch_size):
z_batch = np.random.uniform(-1, 1, size=(batch_size, 100)) # 随机起点值 loss_D, _ = sess.run([gan.loss_D, d_optim],
feed_dict={
gan.z: z_batch,
gan.x: x_batch
}) loss_G, _ = sess.run([gan.loss_G, g_optim],
feed_dict={
gan.z: z_batch,
# gan.x: np.zeros(z_batch.shape)
}) avg_loss += loss_D
count += 1 # 显示预测情况
if True:
avg_loss /= count
z = np.random.normal(size=(batch_size, 100))
excerpt = np.random.randint(100, size=batch_size)
needTest = np.reshape(X[excerpt, :], (-1, 28, 28, 1))
fake_x, real_logits, fake_logits = sess.run([gan.fake_x, gan.real_logits, gan.fake_logits],
feed_dict={gan.z: z, gan.x: needTest})
# accuracy = (np.sum(real_logits > 0.5) + np.sum(fake_logits < 0.5)) / (2 * batch_size)
print('real_logits')
print(len(real_logits))
print('fake_logits')
print(len(fake_logits))
print('\ndiscriminator loss at epoch %d: %f' % (epoch, avg_loss))
# print('\ndiscriminator accuracy at epoch %d: %f' % (epoch, accuracy))
print('----')
print() # curr_img = np.reshape(trainimg[i, :], (28, 28)) # 28 by 28 matrix
curr_img = np.reshape(fake_x[0], (28, 28))
plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
plt.show()
curr_img2 = np.reshape(fake_x[10], (28, 28))
plt.matshow(curr_img2, cmap=plt.get_cmap('gray'))
plt.show()
curr_img3 = np.reshape(fake_x[20], (28, 28))
plt.matshow(curr_img3, cmap=plt.get_cmap('gray'))
plt.show() curr_img4 = np.reshape(fake_x[30], (28, 28))
plt.matshow(curr_img4, cmap=plt.get_cmap('gray'))
plt.show() curr_img5 = np.reshape(fake_x[40], (28, 28))
plt.matshow(curr_img5, cmap=plt.get_cmap('gray'))
plt.show()
# plt.figure(figsize=(28, 28)) # plt.title("" + str(i) + "th Training Data "
# + "Label is " + str(curr_label))
# print("" + str(i) + "th Training Data "
# + "Label is " + str(curr_label)) # plt.scatter(X[:, 0], X[:, 1])
# plt.scatter(fake_x[:, 0], fake_x[:, 1])
# plt.show()

结果

下载链接

GAN生成式对抗网络(三)——mnist数据生成的更多相关文章

  1. GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构

    论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...

  2. GAN生成式对抗网络(一)——原理

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...

  3. 不要怂,就是GAN (生成式对抗网络) (一)

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  4. 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  5. 不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  6. 不要怂,就是GAN (生成式对抗网络) (二)

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  7. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

  8. 不要怂,就是GAN (生成式对抗网络) (三):判别器和生成器 TensorFlow Model

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 utils.py,输入如下代码: import scipy.misc import numpy as np # 保存 ...

  9. 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...

随机推荐

  1. 学习嵌入式为什么要有uboot(深度解析)

    ref:http://www.elecfans.com/d/617674.html     为什么要有uboot 1.1.计算机系统的主要部件 (1)计算机系统就是以CPU为核心来运行的系统. 典型的 ...

  2. 谷歌大脑提出:基于NAS的目标检测模型NAS-FPN,超越Mask R-CNN

    谷歌大脑提出:基于NAS的目标检测模型NAS-FPN,超越Mask R-CNN 朱晓霞发表于目标检测和深度学习订阅 235 广告关闭 11.11 智慧上云 云服务器企业新用户优先购,享双11同等价格 ...

  3. 解决 org.apache.ibatis.binding.BindingException: Invalid bound statement (not found) 以及MyBatis批量加载xml映射文件的方式

    错误 org.apache.ibatis.binding.BindingException: Invalid bound statement (not found) 的出现,意味着项目需要xml文件来 ...

  4. Bootstrap3 CDN 使用手册

    一.一般功能 <link href="https://cdn.bootcss.com/bootstrap/3.3.7/css/bootstrap.css" rel=" ...

  5. vue引入警告:There are multiple modules with names that only differ in casing. This can lead to unexpected behavior when compiling on a filesystem with other case-semantic. Use equal casing. Compare these

    在写vue项目的时候 当我使用 : import dataSource from '../overseaProduct/house/dataSource'; 引入dataSource文件的时候:控制台 ...

  6. cmd查找端口占用情况

    查找端口占用情况:netstat -ano|findstr 4848 查看使用指定端口的应用程序:tasklist|findstr xxxx,xxxx指的是pid 结束指定进程:taskkill /p ...

  7. [JZOJ100019]A--dfn序+扫描线

    [JZOJ100019]A--dfn序+扫描线 题目链接 太懒了自行搜索 分析 这道题查了一个下午的错,真的心态崩了 不过这道题确实妙啊 类比于喝喝喝,我们发现任何一条覆盖了非法路径的路径一定不合法, ...

  8. python实现暴力破解

    import urllib2 import urllib import cookielib import threading import sys import Queue from HTMLPars ...

  9. 现有项目springmvc 小结

    1. 接口接收json数据 @RequestBody JSONObject param 2.返回json数据封装 DataPacket.jsonResult

  10. spring-boot-actuator 常用配置

    management: endpoints: web: base-path: "/" exposure: include: "*" endpoint: heal ...