前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条件的 GAN,和不加约束条件的GAN,我们先来搭建一个简单的 MNIST 数据集上加约束条件的 GAN。

首先下载数据:在  /home/your_name/TensorFlow/DCGAN/ 下建立文件夹 data/mnist,从 http://yann.lecun.com/exdb/mnist/ 网站上下载 mnist 数据集 train-images-idx3-ubyte.gztrain-labels-idx1-ubyte.gzt10k-images-idx3-ubyte.gzt10k-labels-idx1-ubyte.gz 到 mnist 文件夹下得到四个 .gz 文件。

数据下载好之后,在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 read_data.py 读取数据,输入如下代码:

import os
import numpy as np def read_data(): # 数据目录
data_dir = '/home/your_name/TensorFlow/DCGAN/data/mnist' # 打开训练数据
fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
# 转化成 numpy 数组
loaded = np.fromfile(file=fd,dtype=np.uint8)
# 根据 mnist 官网描述的数据格式,图像像素从 16 字节开始
trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float) # 训练 label
fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
trY = loaded[8:].reshape((60000)).astype(np.float) # 测试数据
fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float) # 测试 label
fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teY = loaded[8:].reshape((10000)).astype(np.float) trY = np.asarray(trY)
teY = np.asarray(teY) # 由于生成网络由服从某一分布的噪声生成图片,不需要测试集,
# 所以把训练和测试两部分数据合并
X = np.concatenate((trX, teX), axis=0)
y = np.concatenate((trY, teY), axis=0) # 打乱排序
seed = 547
np.random.seed(seed)
np.random.shuffle(X)
np.random.seed(seed)
np.random.shuffle(y) # 这里,y_vec 表示对网络所加的约束条件,这个条件是类别标签,
# 可以看到,y_vec 实际就是对 y 的独热编码,关于什么是独热编码,
# 请参考 http://www.cnblogs.com/Charles-Wan/p/6207039.html
y_vec = np.zeros((len(y), 10), dtype=np.float)
for i, label in enumerate(y):
y_vec[i,y[i]] = 1.0 return X/255., y_vec

这里顺便说明一下,由于 MNIST 数据总体占得内存不大(可以看下载的文件,最大的一个 45M 左右,)所以这样读取数据是允许的,一般情况下,数据特别庞大的时候,建议把数据转化成 tfrecords,用 TensorFlow 标准的数据读取格式,这样能带来比较高的效率。

然后,定义一些基本的操作层,例如卷积,池化,全连接等层,在 /home/your_name/TensorFlow/DCGAN/ 新建文件 ops.py,输入如下代码:

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm # 常数偏置
def bias(name, shape, bias_start = 0.0, trainable = True): dtype = tf.float32
var = tf.get_variable(name, shape, tf.float32, trainable = trainable,
initializer = tf.constant_initializer(
bias_start, dtype = dtype))
return var # 随机权重
def weight(name, shape, stddev = 0.02, trainable = True): dtype = tf.float32
var = tf.get_variable(name, shape, tf.float32, trainable = trainable,
initializer = tf.random_normal_initializer(
stddev = stddev, dtype = dtype))
return var # 全连接层
def fully_connected(value, output_shape, name = 'fully_connected', with_w = False): shape = value.get_shape().as_list() with tf.variable_scope(name):
weights = weight('weights', [shape[1], output_shape], 0.02)
biases = bias('biases', [output_shape], 0.0) if with_w:
return tf.matmul(value, weights) + biases, weights, biases
else:
return tf.matmul(value, weights) + biases # Leaky-ReLu 层
def lrelu(x, leak=0.2, name = 'lrelu'): with tf.variable_scope(name):
return tf.maximum(x, leak*x, name = name) # ReLu 层
def relu(value, name = 'relu'):
with tf.variable_scope(name):
return tf.nn.relu(value) # 解卷积层
def deconv2d(value, output_shape, k_h = 5, k_w = 5, strides =[1, 2, 2, 1],
name = 'deconv2d', with_w = False): with tf.variable_scope(name):
weights = weight('weights',
[k_h, k_w, output_shape[-1], value.get_shape()[-1]])
deconv = tf.nn.conv2d_transpose(value, weights,
output_shape, strides = strides)
biases = bias('biases', [output_shape[-1]])
deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
if with_w:
return deconv, weights, biases
else:
return deconv # 卷积层
def conv2d(value, output_dim, k_h = 5, k_w = 5,
strides =[1, 2, 2, 1], name = 'conv2d'): with tf.variable_scope(name):
weights = weight('weights',
[k_h, k_w, value.get_shape()[-1], output_dim])
conv = tf.nn.conv2d(value, weights, strides = strides, padding = 'SAME')
biases = bias('biases', [output_dim])
conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) return conv # 把约束条件串联到 feature map
def conv_cond_concat(value, cond, name = 'concat'): # 把张量的维度形状转化成 Python 的 list
value_shapes = value.get_shape().as_list()
cond_shapes = cond.get_shape().as_list() # 在第三个维度上(feature map 维度上)把条件和输入串联起来,
# 条件会被预先设为四维张量的形式,假设输入为 [64, 32, 32, 32] 维的张量,
# 条件为 [64, 32, 32, 10] 维的张量,那么输出就是一个 [64, 32, 32, 42] 维张量
with tf.variable_scope(name):
return tf.concat(3, [value,
cond * tf.ones(value_shapes[0:3] + cond_shapes[3:])]) # Batch Normalization 层
def batch_norm_layer(value, is_train = True, name = 'batch_norm'): with tf.variable_scope(name) as scope:
if is_train:
return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True,
is_training = is_train,
updates_collections = None, scope = scope)
else:
return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True,
is_training = is_train, reuse = True,
updates_collections = None, scope = scope)

TensorFlow 里使用 Batch Normalization 层,有很多种方法,这里我们直接使用官方 contrib 里面的层,其中 decay 指的是滑动平均的 decay,epsilon 作用是加到分母 variance 上避免分母为零,scale 是个布尔变量,如果为真值 True, 结果要乘以 gamma,否则 gamma 不使用,is_train 也是布尔变量,为真值代表训练过程,否则代表测试过程(在 BN 层中,训练过程和测试过程是不同的,具体请参考论文:https://arxiv.org/abs/1502.03167)。关于 batch_norm 的其他的参数,请看参考文献2。

参考文献:

1. https://github.com/carpedm20/DCGAN-tensorflow

2. https://github.com/tensorflow/tensorflow/blob/b826b79718e3e93148c3545e7aa3f90891744cc0/tensorflow/contrib/layers/python/layers/layers.py#L100

不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作的更多相关文章

  1. GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构

    论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...

  2. GAN生成式对抗网络(三)——mnist数据生成

    通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...

  3. GAN生成式对抗网络(一)——原理

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...

  4. 不要怂,就是GAN (生成式对抗网络) (一)

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  5. 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  6. 不要怂,就是GAN (生成式对抗网络) (二)

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  7. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

  8. 不要怂,就是GAN (生成式对抗网络) (五):无约束条件的 GAN 代码与网络的 Graph

    GAN 这个领域发展太快,日新月异,各种 GAN 层出不穷,前几天看到一篇关于 Wasserstein GAN 的文章,讲的很好,在此把它分享出来一起学习:https://zhuanlan.zhihu ...

  9. 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...

随机推荐

  1. chaos-engineering 的一些开源工具

    Chaos Monkey - A resiliency tool that helps applications tolerate random instance failures. The Simi ...

  2. Oracle SEQUENCE 具体说明

     ORACLE  SEQUENCE     ORACLE没有自增数据类型,如需生成业务无关的主键列或惟一约束列,能够用sequence序列实现. CREATE SEQUENCE语句及參数介绍: 创建序 ...

  3. MPEG2-TS音视频同步原理

    一.引言MPEG2系统用于视音频同步以及系统时钟恢复的时间标签分别在ES,PES和TS这3个层次中.  在TS 层, TS头信息包含了节目时钟参考PCR(Program Clock Reference ...

  4. debian的甘特图工具

    sudo apt-get install planner

  5. 【HDU】4352 XHXJ's LIS(数位dp+状压)

    题目 传送门:QWQ 分析 数位dp 状压一下现在的$ O(nlogn) $的$ LIS $的二分数组 数据小,所以更新时直接暴力不用二分了. 代码 #include <bits/stdc++. ...

  6. Web服务的实质介绍

    web应用的实质 众所周知,对于所有的Web应用,本质上其实就是一个socket服务端,用户的浏览器其实就是一个socket客户端 Web应用程序是一种可以通过Web访问的应用程序,程序的最大好处是用 ...

  7. mysql数据导入的时候提示Got a packet bigger than 'max_allowed_packet' bytes

    Got a packet bigger than 'max_allowed_packet' bytes错误 默认可能是2M 把max_allowed_packet设置大于5M试试,我设置为160M,输 ...

  8. OpenCL 直方图

    ▶ 计算直方图,由原子计数和规约计算两部分组成 ● 最简单的版本,代码 // kernel.cl #pragma OPENCL EXTENSION cl_khr_local_int32_base_at ...

  9. js格式化时间 js格式化时间戳

    一个js格式化时间和js格式化时间戳的例子. 代码:/** * 时间对象的格式化; */Date.prototype.format = function(format) { /* * eg:forma ...

  10. pom指定java编译版本和编码

    方法一 <properties> <!-- 文件拷贝时的编码 --> <project.build.sourceEncoding>UTF-8</project ...