不要怂,就是GAN (生成式对抗网络) (三):判别器和生成器 TensorFlow Model
在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 utils.py,输入如下代码:
import scipy.misc
import numpy as np # 保存图片函数
def save_images(images, size, path): """
Save the samples images
The best size number is
int(max(sqrt(image.shape[0]),sqrt(image.shape[1]))) + 1
example:
The batch_size is 64, then the size is recommended [8, 8]
The batch_size is 32, then the size is recommended [6, 6]
""" # 图片归一化,主要用于生成器输出是 tanh 形式的归一化
img = (images + 1.0) / 2.0
h, w = img.shape[1], img.shape[2] # 产生一个大画布,用来保存生成的 batch_size 个图像
merge_img = np.zeros((h * size[0], w * size[1], 3)) # 循环使得画布特定地方值为某一幅图像的值
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
merge_img[j*h:j*h+h, i*w:i*w+w, :] = image # 保存画布
return scipy.misc.imsave(path, merge_img)
这个函数的作用是在训练的过程中保存采样生成的图片。
在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 model.py,定义生成器,判别器和训练过程中的采样网络,在 model.py 输入如下代码:
import tensorflow as tf
from ops import * BATCH_SIZE = 64 # 定义生成器
def generator(z, y, train = True):
# y 是一个 [BATCH_SIZE, 10] 维的向量,把 y 转成四维张量
yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
# 把 y 作为约束条件和 z 拼接起来
z = tf.concat(1, [z, y], name = 'z_concat_y')
# 经过一个全连接,BN 和激活层 ReLu
h1 = tf.nn.relu(batch_norm_layer(fully_connected(z, 1024, 'g_fully_connected1'),
is_train = train, name = 'g_bn1'))
# 把约束条件和上一层拼接起来
h1 = tf.concat(1, [h1, y], name = 'active1_concat_y') h2 = tf.nn.relu(batch_norm_layer(fully_connected(h1, 128 * 49, 'g_fully_connected2'),
is_train = train, name = 'g_bn2'))
h2 = tf.reshape(h2, [64, 7, 7, 128], name = 'h2_reshape')
# 把约束条件和上一层拼接起来
h2 = conv_cond_concat(h2, yb, name = 'active2_concat_y') h3 = tf.nn.relu(batch_norm_layer(deconv2d(h2, [64,14,14,128],
name = 'g_deconv2d3'),
is_train = train, name = 'g_bn3'))
h3 = conv_cond_concat(h3, yb, name = 'active3_concat_y') # 经过一个 sigmoid 函数把值归一化为 0~1 之间,
h4 = tf.nn.sigmoid(deconv2d(h3, [64, 28, 28, 1],
name = 'g_deconv2d4'), name = 'generate_image') return h4 # 定义判别器
def discriminator(image, y, reuse = False): # 因为真实数据和生成数据都要经过判别器,所以需要指定 reuse 是否可用
if reuse:
tf.get_variable_scope().reuse_variables() # 同生成器一样,判别器也需要把约束条件串联进来
yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
x = conv_cond_concat(image, yb, name = 'image_concat_y') # 卷积,激活,串联条件。
h1 = lrelu(conv2d(x, 11, name = 'd_conv2d1'), name = 'lrelu1')
h1 = conv_cond_concat(h1, yb, name = 'h1_concat_yb') h2 = lrelu(batch_norm_layer(conv2d(h1, 74, name = 'd_conv2d2'),
name = 'd_bn2'), name = 'lrelu2')
h2 = tf.reshape(h2, [BATCH_SIZE, -1], name = 'reshape_lrelu2_to_2d')
h2 = tf.concat(1, [h2, y], name = 'lrelu2_concat_y') h3 = lrelu(batch_norm_layer(fully_connected(h2, 1024, name = 'd_fully_connected3'),
name = 'd_bn3'), name = 'lrelu3')
h3 = tf.concat(1,[h3, y], name = 'lrelu3_concat_y') # 全连接层,输出以为 loss 值
h4 = fully_connected(h3, 1, name = 'd_result_withouts_sigmoid') return tf.nn.sigmoid(h4, name = 'discriminator_result_with_sigmoid'), h4 # 定义训练过程中的采样函数
def sampler(z, y, train = True):
tf.get_variable_scope().reuse_variables()
return generator(z, y, train = train)
可以看到,生成器由 7 × 7 变为 14 × 14 再变为 28 × 28大小,每一层都加入了约束条件 y,完美的诠释了论文所给出的网络,之所以要加入 is_train 参数,是由于 Batch_norm 层中训练和测试的时候的过程是不同的,用这个参数区分训练和测试,生成器的最后一层,用了一个 sigmoid 函数把值归一化到 0~1 之间,如果是不加约束的网络,则用 tanh 函数,所以在 save_images 函数中要用到语句:img = (images + 1.0) / 2.0。
sampler 函数的作用是在训练过程中对生成器生成的图片进行采样,所以这个函数必须指定 reuse 可用,关于 reuse 说明,请看:http://www.cnblogs.com/Charles-Wan/p/6200446.html。
参考资料:
1. https://github.com/carpedm20/DCGAN-tensorflow
不要怂,就是GAN (生成式对抗网络) (三):判别器和生成器 TensorFlow Model的更多相关文章
- GAN生成式对抗网络(三)——mnist数据生成
通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...
- GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构
论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...
- GAN生成式对抗网络(一)——原理
生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...
- 不要怂,就是GAN (生成式对抗网络) (一)
前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...
- 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介
前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...
- 不要怂,就是GAN (生成式对抗网络) (二)
前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...
- 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN
在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...
- 不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作
前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...
- 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码
先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...
随机推荐
- CH3401 石头游戏
题意 3401 石头游戏 0x30「数学知识」例题 描述 石头游戏在一个 n 行 m 列 (1≤n,m≤8) 的网格上进行,每个格子对应一种操作序列,操作序列至多有10种,分别用0~9这10个数字指明 ...
- Bootstrap-Plugin:警告框(Alert)插件
ylbtech-Bootstrap-Plugin:警告框(Alert)插件 1.返回顶部 1. Bootstrap 警告框(Alert)插件 警告框(Alert)消息大多是用来向终端用户显示诸如警告或 ...
- Linux常用系统函数
Linux常用系统函数 一.进程控制 fork 创建一个新进程clone 按指定条件创建子进程execve 运行可执行文件exit 中止进程_exit 立即中止当前进程getdtablesize 进程 ...
- Thinkphp 联表查询 表名要全名
我有2个表 表1. 表2 已知表2的user_id 查询满足 表2.wb_id=表1.id 表1的内容 最佳答案 i 2013年11月15日 $result = M()->table(array ...
- [置顶]
linux c常用函数 (待完善)
(1)字符测试函数 isalnum(测试字符是否为英文字母或数字) isalpha(测试字符是否为英文字母) isascii(测试字符是否为ASCII码字符) isblank(测试字符是否为空格字符) ...
- bootstrap插件的一些常用属性介绍
1.下拉菜单 <div class="dropdown"> <button class="btn btn-default dropdown-toggle ...
- Thymeleaf系列五 迭代,if,switch语法
1. 概述 这里介绍thymeleaf的编程语法,本节主要包括如下内容 迭代语法:th:each; iteration status 条件语法:th:if; th:unless switch语法: ...
- vuex语法精简(方便开发查阅)
vuex语法精简(方便开发查阅) store结构 state Getter Mutation actions vuex语法精简(方便开发查阅) 本文只是方便开发的时候快速查阅到相关语法,想看详细内容请 ...
- Day15-Django
all_entries = Entry.objects.all() #查询所有 Entry.objects.filter(pub_date__year=2006) #查询所有pub_date为2006 ...
- python推荐书籍
推荐的python电子书 python学习路线图 优先级 入门:python核心编程 提高:python cookbook 其他 (1).数据分析师 需要有深厚的数理统计基础,但是对程序开发能力不做要 ...