在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 utils.py,输入如下代码:

import scipy.misc
import numpy as np # 保存图片函数
def save_images(images, size, path): """
Save the samples images
The best size number is
int(max(sqrt(image.shape[0]),sqrt(image.shape[1]))) + 1
example:
The batch_size is 64, then the size is recommended [8, 8]
The batch_size is 32, then the size is recommended [6, 6]
""" # 图片归一化,主要用于生成器输出是 tanh 形式的归一化
img = (images + 1.0) / 2.0
h, w = img.shape[1], img.shape[2] # 产生一个大画布,用来保存生成的 batch_size 个图像
merge_img = np.zeros((h * size[0], w * size[1], 3)) # 循环使得画布特定地方值为某一幅图像的值
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
merge_img[j*h:j*h+h, i*w:i*w+w, :] = image # 保存画布
return scipy.misc.imsave(path, merge_img)

这个函数的作用是在训练的过程中保存采样生成的图片。

在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 model.py,定义生成器,判别器和训练过程中的采样网络,在 model.py 输入如下代码:

import tensorflow as tf
from ops import * BATCH_SIZE = 64 # 定义生成器
def generator(z, y, train = True):
# y 是一个 [BATCH_SIZE, 10] 维的向量,把 y 转成四维张量
yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
# 把 y 作为约束条件和 z 拼接起来
z = tf.concat(1, [z, y], name = 'z_concat_y')
# 经过一个全连接,BN 和激活层 ReLu
h1 = tf.nn.relu(batch_norm_layer(fully_connected(z, 1024, 'g_fully_connected1'),
is_train = train, name = 'g_bn1'))
# 把约束条件和上一层拼接起来
h1 = tf.concat(1, [h1, y], name = 'active1_concat_y') h2 = tf.nn.relu(batch_norm_layer(fully_connected(h1, 128 * 49, 'g_fully_connected2'),
is_train = train, name = 'g_bn2'))
h2 = tf.reshape(h2, [64, 7, 7, 128], name = 'h2_reshape')
# 把约束条件和上一层拼接起来
h2 = conv_cond_concat(h2, yb, name = 'active2_concat_y') h3 = tf.nn.relu(batch_norm_layer(deconv2d(h2, [64,14,14,128],
name = 'g_deconv2d3'),
is_train = train, name = 'g_bn3'))
h3 = conv_cond_concat(h3, yb, name = 'active3_concat_y') # 经过一个 sigmoid 函数把值归一化为 0~1 之间,
h4 = tf.nn.sigmoid(deconv2d(h3, [64, 28, 28, 1],
name = 'g_deconv2d4'), name = 'generate_image') return h4 # 定义判别器
def discriminator(image, y, reuse = False): # 因为真实数据和生成数据都要经过判别器,所以需要指定 reuse 是否可用
if reuse:
tf.get_variable_scope().reuse_variables() # 同生成器一样,判别器也需要把约束条件串联进来
yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
x = conv_cond_concat(image, yb, name = 'image_concat_y') # 卷积,激活,串联条件。
h1 = lrelu(conv2d(x, 11, name = 'd_conv2d1'), name = 'lrelu1')
h1 = conv_cond_concat(h1, yb, name = 'h1_concat_yb') h2 = lrelu(batch_norm_layer(conv2d(h1, 74, name = 'd_conv2d2'),
name = 'd_bn2'), name = 'lrelu2')
h2 = tf.reshape(h2, [BATCH_SIZE, -1], name = 'reshape_lrelu2_to_2d')
h2 = tf.concat(1, [h2, y], name = 'lrelu2_concat_y') h3 = lrelu(batch_norm_layer(fully_connected(h2, 1024, name = 'd_fully_connected3'),
name = 'd_bn3'), name = 'lrelu3')
h3 = tf.concat(1,[h3, y], name = 'lrelu3_concat_y') # 全连接层,输出以为 loss 值
h4 = fully_connected(h3, 1, name = 'd_result_withouts_sigmoid') return tf.nn.sigmoid(h4, name = 'discriminator_result_with_sigmoid'), h4 # 定义训练过程中的采样函数
def sampler(z, y, train = True):
tf.get_variable_scope().reuse_variables()
return generator(z, y, train = train)

可以看到,生成器由 7 × 7  变为 14 × 14 再变为 28 × 28大小,每一层都加入了约束条件 y,完美的诠释了论文所给出的网络,之所以要加入 is_train 参数,是由于 Batch_norm 层中训练和测试的时候的过程是不同的,用这个参数区分训练和测试,生成器的最后一层,用了一个 sigmoid 函数把值归一化到 0~1 之间,如果是不加约束的网络,则用 tanh 函数,所以在 save_images 函数中要用到语句:img = (images + 1.0) / 2.0。

sampler 函数的作用是在训练过程中对生成器生成的图片进行采样,所以这个函数必须指定 reuse 可用,关于 reuse 说明,请看:http://www.cnblogs.com/Charles-Wan/p/6200446.html。

参考资料:

1. https://github.com/carpedm20/DCGAN-tensorflow

不要怂,就是GAN (生成式对抗网络) (三):判别器和生成器 TensorFlow Model的更多相关文章

  1. GAN生成式对抗网络(三)——mnist数据生成

    通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...

  2. GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构

    论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...

  3. GAN生成式对抗网络(一)——原理

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...

  4. 不要怂,就是GAN (生成式对抗网络) (一)

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  5. 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  6. 不要怂,就是GAN (生成式对抗网络) (二)

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  7. 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...

  8. 不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  9. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

随机推荐

  1. org.springframework.orm.hibernate3.HibernateTemplate

    当session中出现两个相同标示的(相同主键)的对象,一个是持久态,一个是瞬时态,想更新瞬时态对象到数据库,如果不做处理,则报出异常,session中出现两个相同标示的不同对象异常.处理方法.(业务 ...

  2. graphql cli 开发graphql api flow

    作用 代码生成 schema 处理 脚手架应用创建 项目管理 安装cli npm install -g graphql-cli 初始化项目(使用.graphqlconfig管理) 以下为demo de ...

  3. 使用nomad && consul && fabio 创建简单的微服务系统

    具体每个组件的功能就不详细说明了 nomad 一个调度工具,consul 一个服务发现,健康检查多数据中心支持的工具 fabio 一个基于consul的负载均衡&&动态路由工具,对于集 ...

  4. 不以main为入口的函数

    先看一段程序 #include <stdio.h> void test() { printf("Hello Word!\n"); return 0; } 没有main函 ...

  5. centos6.5下tomcat安装

    1.安装JDK 安装:rpm –ivh jdk-7u5-linux-i586.rpm2.配置Tomcat 解压:tar -zxvf apache-tomcat-8.0.11.tar.gz 移动到/us ...

  6. moodle搭建相关的笔记

    关于moodle内网外网访问问题的解决方案(转) http://blog.chinaunix.net/uid-656828-id-3106027.html

  7. python学习日志

    马上就中秋节,想着再学点新的知识,本来想去继续研究前端知识来着,但是内个烦人的样式css还有js搞的有点脑壳头,以后就主学后端吧,要去死了前端这条心了? 那么寻寻觅觅就入坑最近几年大热的python吧 ...

  8. java工具类-读配置文件

    ///读配置文件 import java.io.InputStream;import java.util.HashMap;import java.util.Map;import java.util.M ...

  9. Nginx+tomcat+redis 集群session共享

    插件资源下载地址:https://github.com/ran-jit/tomcat-cluster-redis-session-manager/releases/tag/2.0.2 一.前置条件 J ...

  10. java web 程序---javabean实例--登陆界面并显示用户名和密码

    重点:注意大小写,不注意细节,这点小事,还需要请教 发现一个问题,也是老师当时写的时候,发现代码没错,但是就是运行问题. 大家看,那个java类,我们要求是所有属性均为私有变量,但是方法为公有的,如果 ...