在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 utils.py,输入如下代码:

import scipy.misc
import numpy as np # 保存图片函数
def save_images(images, size, path): """
Save the samples images
The best size number is
int(max(sqrt(image.shape[0]),sqrt(image.shape[1]))) + 1
example:
The batch_size is 64, then the size is recommended [8, 8]
The batch_size is 32, then the size is recommended [6, 6]
""" # 图片归一化,主要用于生成器输出是 tanh 形式的归一化
img = (images + 1.0) / 2.0
h, w = img.shape[1], img.shape[2] # 产生一个大画布,用来保存生成的 batch_size 个图像
merge_img = np.zeros((h * size[0], w * size[1], 3)) # 循环使得画布特定地方值为某一幅图像的值
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
merge_img[j*h:j*h+h, i*w:i*w+w, :] = image # 保存画布
return scipy.misc.imsave(path, merge_img)

这个函数的作用是在训练的过程中保存采样生成的图片。

在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 model.py,定义生成器,判别器和训练过程中的采样网络,在 model.py 输入如下代码:

import tensorflow as tf
from ops import * BATCH_SIZE = 64 # 定义生成器
def generator(z, y, train = True):
# y 是一个 [BATCH_SIZE, 10] 维的向量,把 y 转成四维张量
yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
# 把 y 作为约束条件和 z 拼接起来
z = tf.concat(1, [z, y], name = 'z_concat_y')
# 经过一个全连接,BN 和激活层 ReLu
h1 = tf.nn.relu(batch_norm_layer(fully_connected(z, 1024, 'g_fully_connected1'),
is_train = train, name = 'g_bn1'))
# 把约束条件和上一层拼接起来
h1 = tf.concat(1, [h1, y], name = 'active1_concat_y') h2 = tf.nn.relu(batch_norm_layer(fully_connected(h1, 128 * 49, 'g_fully_connected2'),
is_train = train, name = 'g_bn2'))
h2 = tf.reshape(h2, [64, 7, 7, 128], name = 'h2_reshape')
# 把约束条件和上一层拼接起来
h2 = conv_cond_concat(h2, yb, name = 'active2_concat_y') h3 = tf.nn.relu(batch_norm_layer(deconv2d(h2, [64,14,14,128],
name = 'g_deconv2d3'),
is_train = train, name = 'g_bn3'))
h3 = conv_cond_concat(h3, yb, name = 'active3_concat_y') # 经过一个 sigmoid 函数把值归一化为 0~1 之间,
h4 = tf.nn.sigmoid(deconv2d(h3, [64, 28, 28, 1],
name = 'g_deconv2d4'), name = 'generate_image') return h4 # 定义判别器
def discriminator(image, y, reuse = False): # 因为真实数据和生成数据都要经过判别器,所以需要指定 reuse 是否可用
if reuse:
tf.get_variable_scope().reuse_variables() # 同生成器一样,判别器也需要把约束条件串联进来
yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
x = conv_cond_concat(image, yb, name = 'image_concat_y') # 卷积,激活,串联条件。
h1 = lrelu(conv2d(x, 11, name = 'd_conv2d1'), name = 'lrelu1')
h1 = conv_cond_concat(h1, yb, name = 'h1_concat_yb') h2 = lrelu(batch_norm_layer(conv2d(h1, 74, name = 'd_conv2d2'),
name = 'd_bn2'), name = 'lrelu2')
h2 = tf.reshape(h2, [BATCH_SIZE, -1], name = 'reshape_lrelu2_to_2d')
h2 = tf.concat(1, [h2, y], name = 'lrelu2_concat_y') h3 = lrelu(batch_norm_layer(fully_connected(h2, 1024, name = 'd_fully_connected3'),
name = 'd_bn3'), name = 'lrelu3')
h3 = tf.concat(1,[h3, y], name = 'lrelu3_concat_y') # 全连接层,输出以为 loss 值
h4 = fully_connected(h3, 1, name = 'd_result_withouts_sigmoid') return tf.nn.sigmoid(h4, name = 'discriminator_result_with_sigmoid'), h4 # 定义训练过程中的采样函数
def sampler(z, y, train = True):
tf.get_variable_scope().reuse_variables()
return generator(z, y, train = train)

可以看到,生成器由 7 × 7  变为 14 × 14 再变为 28 × 28大小,每一层都加入了约束条件 y,完美的诠释了论文所给出的网络,之所以要加入 is_train 参数,是由于 Batch_norm 层中训练和测试的时候的过程是不同的,用这个参数区分训练和测试,生成器的最后一层,用了一个 sigmoid 函数把值归一化到 0~1 之间,如果是不加约束的网络,则用 tanh 函数,所以在 save_images 函数中要用到语句:img = (images + 1.0) / 2.0。

sampler 函数的作用是在训练过程中对生成器生成的图片进行采样,所以这个函数必须指定 reuse 可用,关于 reuse 说明,请看:http://www.cnblogs.com/Charles-Wan/p/6200446.html。

参考资料:

1. https://github.com/carpedm20/DCGAN-tensorflow

不要怂,就是GAN (生成式对抗网络) (三):判别器和生成器 TensorFlow Model的更多相关文章

  1. GAN生成式对抗网络(三)——mnist数据生成

    通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...

  2. GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构

    论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...

  3. GAN生成式对抗网络(一)——原理

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...

  4. 不要怂,就是GAN (生成式对抗网络) (一)

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  5. 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  6. 不要怂,就是GAN (生成式对抗网络) (二)

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  7. 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...

  8. 不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  9. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

随机推荐

  1. C#.NET股票历史数据采集,【附18年历史数据和源代码】

    阅读目录 1.数据采集需求 2.股市数据接口 3.数据库设计 4.关键信息采集 5.源代码和数据库 如果用知乎,可以关注专栏:.NET开源项目和PowerBI社区 重点重点:我没有买股票,没有买股票, ...

  2. Chrome 解决flash问题

    Chrome 无法显示使用插件的内容 Chrome 不再支持很多插件.不过网站创建者已经通过更安全的方式,将多数这类功能添加到 Chrome 中. 为什么 NPAPI 插件现在无法正常运行过去,许多插 ...

  3. CentOS 6.0 VNC远程桌面配置方法(转帖)

    问题:新装开发机,安装VNC软件后,按照下面文档配置后,无法用VNC view连接,关闭防火墙后可以连上 解决方法:说明问题出在防火墙配置上,除了允许15900端口外,还有其他要设,经过排查后,加上如 ...

  4. Java进行spark计算

    首先在Linux环境安装spark: 可以从如下地址下载最新版本的spark: https://spark.apache.org/downloads.html 这个下载下来后是个tgz的压缩包,解压后 ...

  5. Linux新手入门:Unable to locate package错误解决办法

    最近刚开始接触Linux,在虚拟机中装了个Ubuntu,当前的版本是Ubuntu 11.10,装好后自然少不了安装一些软件,在设置了软件的源后,就开始了 sudo apt-get install,结果 ...

  6. 分类预测输出precision,recall,accuracy,auc和tp,tn,fp,fn矩阵

    此次我做的实验是二分类问题,输出precision,recall,accuracy,auc # -*- coding: utf-8 -*- #from sklearn.neighbors import ...

  7. 管理Linux服务器的用户和组

    管理Linux服务器的用户和组 Linux操作系统是一个多用户多任务的操作系统,允许多个用户同时登录到系统,使用系统资源. 为了使所有用户的工作顺利进行,保护每个用户的文件和进程,规范每个用户的权限, ...

  8. Android开发入门——Button绑定监听事件三种方式

    import android.app.Activity; import android.os.Bundle;import android.view.View;import android.widget ...

  9. 关于TP5中的依赖注入和容器和facade

    看了不少的文章,也看了官方的介绍,还是根据自己的理解,写写看法,理清下思路 只是单纯的说依赖注入Dependency Injection和容器 别的不白扯 比如有A,B,C三个类 A类的1方法依赖B类 ...

  10. servlet的小例子

    servlet测试 首先,打开myeclipse,file|new|Dynamic Web Project 会出现一个对话框,在Project name的文本框中输入:FirstServlet:然后点 ...