前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条件的 GAN,和不加约束条件的GAN,我们先来搭建一个简单的 MNIST 数据集上加约束条件的 GAN。

首先下载数据:在  /home/your_name/TensorFlow/DCGAN/ 下建立文件夹 data/mnist,从 http://yann.lecun.com/exdb/mnist/ 网站上下载 mnist 数据集 train-images-idx3-ubyte.gztrain-labels-idx1-ubyte.gzt10k-images-idx3-ubyte.gzt10k-labels-idx1-ubyte.gz 到 mnist 文件夹下得到四个 .gz 文件。

数据下载好之后,在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 read_data.py 读取数据,输入如下代码:

import os
import numpy as np def read_data(): # 数据目录
data_dir = '/home/your_name/TensorFlow/DCGAN/data/mnist' # 打开训练数据
fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
# 转化成 numpy 数组
loaded = np.fromfile(file=fd,dtype=np.uint8)
# 根据 mnist 官网描述的数据格式,图像像素从 16 字节开始
trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float) # 训练 label
fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
trY = loaded[8:].reshape((60000)).astype(np.float) # 测试数据
fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float) # 测试 label
fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teY = loaded[8:].reshape((10000)).astype(np.float) trY = np.asarray(trY)
teY = np.asarray(teY) # 由于生成网络由服从某一分布的噪声生成图片,不需要测试集,
# 所以把训练和测试两部分数据合并
X = np.concatenate((trX, teX), axis=0)
y = np.concatenate((trY, teY), axis=0) # 打乱排序
seed = 547
np.random.seed(seed)
np.random.shuffle(X)
np.random.seed(seed)
np.random.shuffle(y) # 这里,y_vec 表示对网络所加的约束条件,这个条件是类别标签,
# 可以看到,y_vec 实际就是对 y 的独热编码,关于什么是独热编码,
# 请参考 http://www.cnblogs.com/Charles-Wan/p/6207039.html
y_vec = np.zeros((len(y), 10), dtype=np.float)
for i, label in enumerate(y):
y_vec[i,y[i]] = 1.0 return X/255., y_vec

这里顺便说明一下,由于 MNIST 数据总体占得内存不大(可以看下载的文件,最大的一个 45M 左右,)所以这样读取数据是允许的,一般情况下,数据特别庞大的时候,建议把数据转化成 tfrecords,用 TensorFlow 标准的数据读取格式,这样能带来比较高的效率。

然后,定义一些基本的操作层,例如卷积,池化,全连接等层,在 /home/your_name/TensorFlow/DCGAN/ 新建文件 ops.py,输入如下代码:

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm # 常数偏置
def bias(name, shape, bias_start = 0.0, trainable = True): dtype = tf.float32
var = tf.get_variable(name, shape, tf.float32, trainable = trainable,
initializer = tf.constant_initializer(
bias_start, dtype = dtype))
return var # 随机权重
def weight(name, shape, stddev = 0.02, trainable = True): dtype = tf.float32
var = tf.get_variable(name, shape, tf.float32, trainable = trainable,
initializer = tf.random_normal_initializer(
stddev = stddev, dtype = dtype))
return var # 全连接层
def fully_connected(value, output_shape, name = 'fully_connected', with_w = False): shape = value.get_shape().as_list() with tf.variable_scope(name):
weights = weight('weights', [shape[1], output_shape], 0.02)
biases = bias('biases', [output_shape], 0.0) if with_w:
return tf.matmul(value, weights) + biases, weights, biases
else:
return tf.matmul(value, weights) + biases # Leaky-ReLu 层
def lrelu(x, leak=0.2, name = 'lrelu'): with tf.variable_scope(name):
return tf.maximum(x, leak*x, name = name) # ReLu 层
def relu(value, name = 'relu'):
with tf.variable_scope(name):
return tf.nn.relu(value) # 解卷积层
def deconv2d(value, output_shape, k_h = 5, k_w = 5, strides =[1, 2, 2, 1],
name = 'deconv2d', with_w = False): with tf.variable_scope(name):
weights = weight('weights',
[k_h, k_w, output_shape[-1], value.get_shape()[-1]])
deconv = tf.nn.conv2d_transpose(value, weights,
output_shape, strides = strides)
biases = bias('biases', [output_shape[-1]])
deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
if with_w:
return deconv, weights, biases
else:
return deconv # 卷积层
def conv2d(value, output_dim, k_h = 5, k_w = 5,
strides =[1, 2, 2, 1], name = 'conv2d'): with tf.variable_scope(name):
weights = weight('weights',
[k_h, k_w, value.get_shape()[-1], output_dim])
conv = tf.nn.conv2d(value, weights, strides = strides, padding = 'SAME')
biases = bias('biases', [output_dim])
conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) return conv # 把约束条件串联到 feature map
def conv_cond_concat(value, cond, name = 'concat'): # 把张量的维度形状转化成 Python 的 list
value_shapes = value.get_shape().as_list()
cond_shapes = cond.get_shape().as_list() # 在第三个维度上(feature map 维度上)把条件和输入串联起来,
# 条件会被预先设为四维张量的形式,假设输入为 [64, 32, 32, 32] 维的张量,
# 条件为 [64, 32, 32, 10] 维的张量,那么输出就是一个 [64, 32, 32, 42] 维张量
with tf.variable_scope(name):
return tf.concat(3, [value,
cond * tf.ones(value_shapes[0:3] + cond_shapes[3:])]) # Batch Normalization 层
def batch_norm_layer(value, is_train = True, name = 'batch_norm'): with tf.variable_scope(name) as scope:
if is_train:
return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True,
is_training = is_train,
updates_collections = None, scope = scope)
else:
return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True,
is_training = is_train, reuse = True,
updates_collections = None, scope = scope)

TensorFlow 里使用 Batch Normalization 层,有很多种方法,这里我们直接使用官方 contrib 里面的层,其中 decay 指的是滑动平均的 decay,epsilon 作用是加到分母 variance 上避免分母为零,scale 是个布尔变量,如果为真值 True, 结果要乘以 gamma,否则 gamma 不使用,is_train 也是布尔变量,为真值代表训练过程,否则代表测试过程(在 BN 层中,训练过程和测试过程是不同的,具体请参考论文:https://arxiv.org/abs/1502.03167)。关于 batch_norm 的其他的参数,请看参考文献2。

参考文献:

1. https://github.com/carpedm20/DCGAN-tensorflow

2. https://github.com/tensorflow/tensorflow/blob/b826b79718e3e93148c3545e7aa3f90891744cc0/tensorflow/contrib/layers/python/layers/layers.py#L100

不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作的更多相关文章

  1. GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构

    论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...

  2. GAN生成式对抗网络(三)——mnist数据生成

    通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...

  3. GAN生成式对抗网络(一)——原理

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...

  4. 不要怂,就是GAN (生成式对抗网络) (一)

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  5. 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  6. 不要怂,就是GAN (生成式对抗网络) (二)

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  7. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

  8. 不要怂,就是GAN (生成式对抗网络) (五):无约束条件的 GAN 代码与网络的 Graph

    GAN 这个领域发展太快,日新月异,各种 GAN 层出不穷,前几天看到一篇关于 Wasserstein GAN 的文章,讲的很好,在此把它分享出来一起学习:https://zhuanlan.zhihu ...

  9. 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...

随机推荐

  1. test20181018 B君的第一题

    题意 分析 考场爆零做法 考虑dp,用\(f(i,j,0/1)\)表示i及其子树中形成j个边连通块的方案数,其中i是否向外连边. \(O(n^3)\),转移方程太复杂就打挂了. #include< ...

  2. 理解JAVA虚拟机(上)

    2016-04-16 23:10:50 在这里,我们进一步认识JAVA及JAVA虚拟机,包括它的体系结构和垃圾回收机制等,并通过虚拟机监控工具进行简单的性能调优. 一. JVM相关概念        ...

  3. CentOS 6安装php加速软件Zend Guard(转)

    (尚未验证) PHP5.3以上的版本不再支持Zend Optimizer,已经被全新的 Zend Guard Loader 取代,下面是安装Zend Guard具体步骤,以下操作均在终端命令行执行 1 ...

  4. fatal error: mysql.h: No such file or directory

    在ubuntu系统下安装mysql之后,和数据库连接的时候,出现如下错误:fatal error: mysql.h: No such file or directory 是因为缺少链接库,执行如下命名 ...

  5. jQuery layer弹出层插件 http://layer.layui.com/直接上官网学

    在许多网站中,经常用到弹出层,有时候为了达到更好的用户体验,你将写繁琐的css跟js,这款 jquery-layer可以让你想到即可做到的web弹窗/层js组件.layer侧重于用户灵活的自定义,为不 ...

  6. 七.jQuery源码解析之.toArray()

    toArray()是将jQuery对象转换成数组 从源码中可以看到,这些常见的方法,都是直接从原生的 javascript中"借鉴"过来的.为什么这么说呢? 225行中,在运行时, ...

  7. java8时间使用小结

    //LocalDate代表一个IOS格式(yyyy-MM-dd)的日期 获取当前的日期: LocalDate localDate = LocalDate.now();//LocalDate: 表示没有 ...

  8. VueRouter

    使用VueRouter的前提: 1, 必须导入vue-router.js文件    2, 要有VueRouter()实例    3, 要把VueRouter实例挂载到Vue实例中 4, 路由的入口   ...

  9. IPv4正则表达式匹配

    IP地址的长度为32位,分为4段,每段8位.用十进制数字表示,每段数字范围为0~255,段与段之间用英文句点“.”隔开.例如:某台计算机IP地址为111.22.33.4. 分析IP地址的组成特点:25 ...

  10. Django---Xss过滤以及单例模式

    Xss过滤 在表单填写的过程中我们就用到textarea,富文本编辑框,里面要用户输入相关的内容.如果有的人想要搞怪,在里面写一些js代码或者修改编辑的时候修改源代码,那提交上去之后就会使得页面显示不 ...