不要怂,就是GAN (生成式对抗网络) (三):判别器和生成器 TensorFlow Model
在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 utils.py,输入如下代码:
import scipy.misc
import numpy as np # 保存图片函数
def save_images(images, size, path): """
Save the samples images
The best size number is
int(max(sqrt(image.shape[0]),sqrt(image.shape[1]))) + 1
example:
The batch_size is 64, then the size is recommended [8, 8]
The batch_size is 32, then the size is recommended [6, 6]
""" # 图片归一化,主要用于生成器输出是 tanh 形式的归一化
img = (images + 1.0) / 2.0
h, w = img.shape[1], img.shape[2] # 产生一个大画布,用来保存生成的 batch_size 个图像
merge_img = np.zeros((h * size[0], w * size[1], 3)) # 循环使得画布特定地方值为某一幅图像的值
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
merge_img[j*h:j*h+h, i*w:i*w+w, :] = image # 保存画布
return scipy.misc.imsave(path, merge_img)
这个函数的作用是在训练的过程中保存采样生成的图片。
在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 model.py,定义生成器,判别器和训练过程中的采样网络,在 model.py 输入如下代码:
import tensorflow as tf
from ops import * BATCH_SIZE = 64 # 定义生成器
def generator(z, y, train = True):
# y 是一个 [BATCH_SIZE, 10] 维的向量,把 y 转成四维张量
yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
# 把 y 作为约束条件和 z 拼接起来
z = tf.concat(1, [z, y], name = 'z_concat_y')
# 经过一个全连接,BN 和激活层 ReLu
h1 = tf.nn.relu(batch_norm_layer(fully_connected(z, 1024, 'g_fully_connected1'),
is_train = train, name = 'g_bn1'))
# 把约束条件和上一层拼接起来
h1 = tf.concat(1, [h1, y], name = 'active1_concat_y') h2 = tf.nn.relu(batch_norm_layer(fully_connected(h1, 128 * 49, 'g_fully_connected2'),
is_train = train, name = 'g_bn2'))
h2 = tf.reshape(h2, [64, 7, 7, 128], name = 'h2_reshape')
# 把约束条件和上一层拼接起来
h2 = conv_cond_concat(h2, yb, name = 'active2_concat_y') h3 = tf.nn.relu(batch_norm_layer(deconv2d(h2, [64,14,14,128],
name = 'g_deconv2d3'),
is_train = train, name = 'g_bn3'))
h3 = conv_cond_concat(h3, yb, name = 'active3_concat_y') # 经过一个 sigmoid 函数把值归一化为 0~1 之间,
h4 = tf.nn.sigmoid(deconv2d(h3, [64, 28, 28, 1],
name = 'g_deconv2d4'), name = 'generate_image') return h4 # 定义判别器
def discriminator(image, y, reuse = False): # 因为真实数据和生成数据都要经过判别器,所以需要指定 reuse 是否可用
if reuse:
tf.get_variable_scope().reuse_variables() # 同生成器一样,判别器也需要把约束条件串联进来
yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
x = conv_cond_concat(image, yb, name = 'image_concat_y') # 卷积,激活,串联条件。
h1 = lrelu(conv2d(x, 11, name = 'd_conv2d1'), name = 'lrelu1')
h1 = conv_cond_concat(h1, yb, name = 'h1_concat_yb') h2 = lrelu(batch_norm_layer(conv2d(h1, 74, name = 'd_conv2d2'),
name = 'd_bn2'), name = 'lrelu2')
h2 = tf.reshape(h2, [BATCH_SIZE, -1], name = 'reshape_lrelu2_to_2d')
h2 = tf.concat(1, [h2, y], name = 'lrelu2_concat_y') h3 = lrelu(batch_norm_layer(fully_connected(h2, 1024, name = 'd_fully_connected3'),
name = 'd_bn3'), name = 'lrelu3')
h3 = tf.concat(1,[h3, y], name = 'lrelu3_concat_y') # 全连接层,输出以为 loss 值
h4 = fully_connected(h3, 1, name = 'd_result_withouts_sigmoid') return tf.nn.sigmoid(h4, name = 'discriminator_result_with_sigmoid'), h4 # 定义训练过程中的采样函数
def sampler(z, y, train = True):
tf.get_variable_scope().reuse_variables()
return generator(z, y, train = train)
可以看到,生成器由 7 × 7 变为 14 × 14 再变为 28 × 28大小,每一层都加入了约束条件 y,完美的诠释了论文所给出的网络,之所以要加入 is_train 参数,是由于 Batch_norm 层中训练和测试的时候的过程是不同的,用这个参数区分训练和测试,生成器的最后一层,用了一个 sigmoid 函数把值归一化到 0~1 之间,如果是不加约束的网络,则用 tanh 函数,所以在 save_images 函数中要用到语句:img = (images + 1.0) / 2.0。
sampler 函数的作用是在训练过程中对生成器生成的图片进行采样,所以这个函数必须指定 reuse 可用,关于 reuse 说明,请看:http://www.cnblogs.com/Charles-Wan/p/6200446.html。
参考资料:
1. https://github.com/carpedm20/DCGAN-tensorflow
不要怂,就是GAN (生成式对抗网络) (三):判别器和生成器 TensorFlow Model的更多相关文章
- GAN生成式对抗网络(三)——mnist数据生成
通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...
- GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构
论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...
- GAN生成式对抗网络(一)——原理
生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...
- 不要怂,就是GAN (生成式对抗网络) (一)
前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...
- 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介
前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...
- 不要怂,就是GAN (生成式对抗网络) (二)
前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...
- 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN
在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...
- 不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作
前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...
- 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码
先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...
随机推荐
- org.springframework.orm.hibernate3.HibernateTemplate
当session中出现两个相同标示的(相同主键)的对象,一个是持久态,一个是瞬时态,想更新瞬时态对象到数据库,如果不做处理,则报出异常,session中出现两个相同标示的不同对象异常.处理方法.(业务 ...
- graphql cli 开发graphql api flow
作用 代码生成 schema 处理 脚手架应用创建 项目管理 安装cli npm install -g graphql-cli 初始化项目(使用.graphqlconfig管理) 以下为demo de ...
- 使用nomad && consul && fabio 创建简单的微服务系统
具体每个组件的功能就不详细说明了 nomad 一个调度工具,consul 一个服务发现,健康检查多数据中心支持的工具 fabio 一个基于consul的负载均衡&&动态路由工具,对于集 ...
- 不以main为入口的函数
先看一段程序 #include <stdio.h> void test() { printf("Hello Word!\n"); return 0; } 没有main函 ...
- centos6.5下tomcat安装
1.安装JDK 安装:rpm –ivh jdk-7u5-linux-i586.rpm2.配置Tomcat 解压:tar -zxvf apache-tomcat-8.0.11.tar.gz 移动到/us ...
- moodle搭建相关的笔记
关于moodle内网外网访问问题的解决方案(转) http://blog.chinaunix.net/uid-656828-id-3106027.html
- python学习日志
马上就中秋节,想着再学点新的知识,本来想去继续研究前端知识来着,但是内个烦人的样式css还有js搞的有点脑壳头,以后就主学后端吧,要去死了前端这条心了? 那么寻寻觅觅就入坑最近几年大热的python吧 ...
- java工具类-读配置文件
///读配置文件 import java.io.InputStream;import java.util.HashMap;import java.util.Map;import java.util.M ...
- Nginx+tomcat+redis 集群session共享
插件资源下载地址:https://github.com/ran-jit/tomcat-cluster-redis-session-manager/releases/tag/2.0.2 一.前置条件 J ...
- java web 程序---javabean实例--登陆界面并显示用户名和密码
重点:注意大小写,不注意细节,这点小事,还需要请教 发现一个问题,也是老师当时写的时候,发现代码没错,但是就是运行问题. 大家看,那个java类,我们要求是所有属性均为私有变量,但是方法为公有的,如果 ...