前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条件的 GAN,和不加约束条件的GAN,我们先来搭建一个简单的 MNIST 数据集上加约束条件的 GAN。

首先下载数据:在  /home/your_name/TensorFlow/DCGAN/ 下建立文件夹 data/mnist,从 http://yann.lecun.com/exdb/mnist/ 网站上下载 mnist 数据集 train-images-idx3-ubyte.gztrain-labels-idx1-ubyte.gzt10k-images-idx3-ubyte.gzt10k-labels-idx1-ubyte.gz 到 mnist 文件夹下得到四个 .gz 文件。

数据下载好之后,在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 read_data.py 读取数据,输入如下代码:

import os
import numpy as np def read_data(): # 数据目录
data_dir = '/home/your_name/TensorFlow/DCGAN/data/mnist' # 打开训练数据
fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
# 转化成 numpy 数组
loaded = np.fromfile(file=fd,dtype=np.uint8)
# 根据 mnist 官网描述的数据格式,图像像素从 16 字节开始
trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float) # 训练 label
fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
trY = loaded[8:].reshape((60000)).astype(np.float) # 测试数据
fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float) # 测试 label
fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teY = loaded[8:].reshape((10000)).astype(np.float) trY = np.asarray(trY)
teY = np.asarray(teY) # 由于生成网络由服从某一分布的噪声生成图片,不需要测试集,
# 所以把训练和测试两部分数据合并
X = np.concatenate((trX, teX), axis=0)
y = np.concatenate((trY, teY), axis=0) # 打乱排序
seed = 547
np.random.seed(seed)
np.random.shuffle(X)
np.random.seed(seed)
np.random.shuffle(y) # 这里,y_vec 表示对网络所加的约束条件,这个条件是类别标签,
# 可以看到,y_vec 实际就是对 y 的独热编码,关于什么是独热编码,
# 请参考 http://www.cnblogs.com/Charles-Wan/p/6207039.html
y_vec = np.zeros((len(y), 10), dtype=np.float)
for i, label in enumerate(y):
y_vec[i,y[i]] = 1.0 return X/255., y_vec

这里顺便说明一下,由于 MNIST 数据总体占得内存不大(可以看下载的文件,最大的一个 45M 左右,)所以这样读取数据是允许的,一般情况下,数据特别庞大的时候,建议把数据转化成 tfrecords,用 TensorFlow 标准的数据读取格式,这样能带来比较高的效率。

然后,定义一些基本的操作层,例如卷积,池化,全连接等层,在 /home/your_name/TensorFlow/DCGAN/ 新建文件 ops.py,输入如下代码:

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm # 常数偏置
def bias(name, shape, bias_start = 0.0, trainable = True): dtype = tf.float32
var = tf.get_variable(name, shape, tf.float32, trainable = trainable,
initializer = tf.constant_initializer(
bias_start, dtype = dtype))
return var # 随机权重
def weight(name, shape, stddev = 0.02, trainable = True): dtype = tf.float32
var = tf.get_variable(name, shape, tf.float32, trainable = trainable,
initializer = tf.random_normal_initializer(
stddev = stddev, dtype = dtype))
return var # 全连接层
def fully_connected(value, output_shape, name = 'fully_connected', with_w = False): shape = value.get_shape().as_list() with tf.variable_scope(name):
weights = weight('weights', [shape[1], output_shape], 0.02)
biases = bias('biases', [output_shape], 0.0) if with_w:
return tf.matmul(value, weights) + biases, weights, biases
else:
return tf.matmul(value, weights) + biases # Leaky-ReLu 层
def lrelu(x, leak=0.2, name = 'lrelu'): with tf.variable_scope(name):
return tf.maximum(x, leak*x, name = name) # ReLu 层
def relu(value, name = 'relu'):
with tf.variable_scope(name):
return tf.nn.relu(value) # 解卷积层
def deconv2d(value, output_shape, k_h = 5, k_w = 5, strides =[1, 2, 2, 1],
name = 'deconv2d', with_w = False): with tf.variable_scope(name):
weights = weight('weights',
[k_h, k_w, output_shape[-1], value.get_shape()[-1]])
deconv = tf.nn.conv2d_transpose(value, weights,
output_shape, strides = strides)
biases = bias('biases', [output_shape[-1]])
deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
if with_w:
return deconv, weights, biases
else:
return deconv # 卷积层
def conv2d(value, output_dim, k_h = 5, k_w = 5,
strides =[1, 2, 2, 1], name = 'conv2d'): with tf.variable_scope(name):
weights = weight('weights',
[k_h, k_w, value.get_shape()[-1], output_dim])
conv = tf.nn.conv2d(value, weights, strides = strides, padding = 'SAME')
biases = bias('biases', [output_dim])
conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) return conv # 把约束条件串联到 feature map
def conv_cond_concat(value, cond, name = 'concat'): # 把张量的维度形状转化成 Python 的 list
value_shapes = value.get_shape().as_list()
cond_shapes = cond.get_shape().as_list() # 在第三个维度上(feature map 维度上)把条件和输入串联起来,
# 条件会被预先设为四维张量的形式,假设输入为 [64, 32, 32, 32] 维的张量,
# 条件为 [64, 32, 32, 10] 维的张量,那么输出就是一个 [64, 32, 32, 42] 维张量
with tf.variable_scope(name):
return tf.concat(3, [value,
cond * tf.ones(value_shapes[0:3] + cond_shapes[3:])]) # Batch Normalization 层
def batch_norm_layer(value, is_train = True, name = 'batch_norm'): with tf.variable_scope(name) as scope:
if is_train:
return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True,
is_training = is_train,
updates_collections = None, scope = scope)
else:
return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True,
is_training = is_train, reuse = True,
updates_collections = None, scope = scope)

TensorFlow 里使用 Batch Normalization 层,有很多种方法,这里我们直接使用官方 contrib 里面的层,其中 decay 指的是滑动平均的 decay,epsilon 作用是加到分母 variance 上避免分母为零,scale 是个布尔变量,如果为真值 True, 结果要乘以 gamma,否则 gamma 不使用,is_train 也是布尔变量,为真值代表训练过程,否则代表测试过程(在 BN 层中,训练过程和测试过程是不同的,具体请参考论文:https://arxiv.org/abs/1502.03167)。关于 batch_norm 的其他的参数,请看参考文献2。

参考文献:

1. https://github.com/carpedm20/DCGAN-tensorflow

2. https://github.com/tensorflow/tensorflow/blob/b826b79718e3e93148c3545e7aa3f90891744cc0/tensorflow/contrib/layers/python/layers/layers.py#L100

不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作的更多相关文章

  1. GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构

    论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...

  2. GAN生成式对抗网络(三)——mnist数据生成

    通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...

  3. GAN生成式对抗网络(一)——原理

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...

  4. 不要怂,就是GAN (生成式对抗网络) (一)

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  5. 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  6. 不要怂,就是GAN (生成式对抗网络) (二)

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  7. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

  8. 不要怂,就是GAN (生成式对抗网络) (五):无约束条件的 GAN 代码与网络的 Graph

    GAN 这个领域发展太快,日新月异,各种 GAN 层出不穷,前几天看到一篇关于 Wasserstein GAN 的文章,讲的很好,在此把它分享出来一起学习:https://zhuanlan.zhihu ...

  9. 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...

随机推荐

  1. 【thrift】thrift详解

    转载:http://zheming.wang/thrift-rpcxiang-jie.html Thrift Thrift是一个跨语言的服务部署框架,最初由Facebook于2007年开发,2008年 ...

  2. StreamSets Data Collector Edge 说明

    Data Collector Edge 是不包含界面的agent 安装 下载包 https://streamsets.com/opensource tar xf streamsets-datacoll ...

  3. Javascript模块化编程require.js的用法

    JS模块化工具requirejs教程(一):初识requirejs http://www.runoob.com/w3cnote/requirejs-tutorial-1.html JS模块化工具req ...

  4. Spring Could与Dubbo、Docker、K8S

    如果你是在一个中小型项目中应用Spring Cloud,那么你不需要太多的改造和适配,就可以实现微服务的基本功能.但是如果是在大型项目中实践微服务,可能会发现需要处理的问题还是比较多,尤其是项目中老代 ...

  5. 3——FFMPEG之解复用器-----AVInputFormat(转)

    1. 数据结构: AVInputFormat为FFMPEG的解复用器对象,通过调用av_register_all(),FFMPEG所有的解复用器保存在以first_iformat为链表头的链表中,且还 ...

  6. POJ2226Muddy Fields

    题目:http://poj.org/problem?id=2226 巧妙建图:以行或列上的联通块作为点,每个泥格子作为边,求最小点覆盖就可以了! 于是用匈牙利算法找最大匹配.注意要对右部点记录每一个左 ...

  7. MyEclipse部署项目到Tomcat上,但是classes文件夹下没有编译项目

    在MyEclipse中把项目部署到Tomcat上,但是Tomcat下的classes文件夹下没有编译项目解决方法:1-直接在点击菜单栏的Project--clean,对项目进行clean2-查看菜单栏 ...

  8. cowboy添加验证码

    参考的http://beebole.com/blog/erlang/how-to-implement-captcha-in-erlang-web-application/,移到cowboy,修改了下: ...

  9. HDU 2143 Can you find it?(基础二分)

    Can you find it? Time Limit: 10000/3000 MS (Java/Others)    Memory Limit: 32768/10000 K (Java/Others ...

  10. Django将.csv文件(excel文件)显示到网页上

    今天,我成功将项目要导入的测试数据导入并呈现了,虽然还不是很完美,但我之后仍会继续改进. 1.首先在主页面上加一个超链接按钮: 其它的不需要管,其它是我的另一个项目,没什么大用的 2.之后配置URL: ...