GAN网络架构分析

上图即为GAN的逻辑架构,其中的noise vector就是特征向量z,real images就是输入变量x,标签的标准比较简单(二分类么),real的就是tf.ones,fake的就是tf.zeros。

网络具体形状大体如上,具体数值有所调整,生成器过程为:噪声向量-全连接-卷积-卷积-卷积,辨别器过程:图片-卷积-卷积-全连接-全连接。

和预想的不同,实际上数据在生成器中并不是从无到有由小变大的过程,而是由3136(56*56)经过正常卷积步骤下降为28*28的过程。

实现如下:

import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('../../Mnist_data') """测试数据""" # sample_image = mnist.train.next_batch(1)[0]
# print(sample_image.shape)
# sample_image = sample_image.reshape([28, 28])
# plt.imshow(sample_image, cmap='Greys') """分辨器""" def discriminator(images, reuse=None):
with tf.variable_scope(tf.get_variable_scope(), reuse=reuse) as scope:
# 卷积 + 激活 + 池化
d_w1 = tf.get_variable('d_w1',[5,5,1,32],initializer=tf.truncated_normal_initializer(stddev=0.02))
d_b1 = tf.get_variable('d_b1',[32],initializer=tf.constant_initializer(0))
d1 = tf.nn.conv2d(input=images,filter=d_w1,strides=[1,1,1,1],padding='SAME')
d1 = d1 + d_b1
d1 = tf.nn.relu(d1)
d1 = tf.nn.avg_pool(d1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') # 卷积 + 激活 + 池化
d_w2 = tf.get_variable('d_w2',[5,5,32,64],initializer=tf.truncated_normal_initializer(stddev=0.02))
d_b2 = tf.get_variable('d_b2',[64],initializer=tf.constant_initializer(0))
d2 = tf.nn.conv2d(input=d1,filter=d_w2,strides=[1,1,1,1],padding='SAME')
d2 = d2 + d_b2
d2 = tf.nn.relu(d2)
d2 = tf.nn.avg_pool(d2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') # 全连接 + 激活
d_w3 = tf.get_variable('d_w3',[7 * 7 * 64,1024],initializer=tf.truncated_normal_initializer(stddev=0.02))
d_b3 = tf.get_variable('d_b3',[1024],initializer=tf.constant_initializer(0))
d3 = tf.reshape(d2,[-1,7 * 7 * 64])
d3 = tf.matmul(d3,d_w3)
d3 = d3 + d_b3
d3 = tf.nn.relu(d3) # 全连接
d_w4 = tf.get_variable('d_w4',[1024,1],initializer=tf.truncated_normal_initializer(stddev=0.02))
d_b4 = tf.get_variable('d_b4',[1],initializer=tf.constant_initializer(0))
d4 = tf.matmul(d3,d_w4) + d_b4 # 最后输出一个非尺度化的值
return d4 """生成器""" def generator(z, batch_size, z_dim, reuse=False):
'''接收特征向量z,由z生成图片''' with tf.variable_scope(tf.get_variable_scope(),reuse=reuse):
# 全连接 + 批正则化 + 激活
# z_dim -> 3136 -> 56*56*1
g_w1 = tf.get_variable('g_w1', [z_dim, 3136], dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.02))
g_b1 = tf.get_variable('g_b1', [3136], initializer=tf.truncated_normal_initializer(stddev=0.02))
g1 = tf.matmul(z, g_w1) + g_b1
g1 = tf.reshape(g1, [-1, 56, 56, 1])
g1 = tf.contrib.layers.batch_norm(g1, epsilon=1e-5, scope='bn1')
g1 = tf.nn.relu(g1) # 卷积 + 批正则化 + 激活
g_w2 = tf.get_variable('g_w2',[3,3,1,z_dim / 2],dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.02))
g_b2 = tf.get_variable('g_b2',[z_dim / 2],initializer=tf.truncated_normal_initializer(stddev=0.02))
g2 = tf.nn.conv2d(g1,g_w2,strides=[1,2,2,1],padding='SAME')
g2 = g2 + g_b2
g2 = tf.contrib.layers.batch_norm(g2,epsilon=1e-5,scope='bn2')
g2 = tf.nn.relu(g2)
g2 = tf.image.resize_images(g2,[56,56]) # 卷积 + 批正则化 + 激活
g_w3 = tf.get_variable('g_w3',[3,3,z_dim / 2,z_dim / 4],dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.02))
g_b3 = tf.get_variable('g_b3',[z_dim / 4],initializer=tf.truncated_normal_initializer(stddev=0.02))
g3 = tf.nn.conv2d(g2,g_w3,strides=[1,2,2,1],padding='SAME')
g3 = g3 + g_b3
g3 = tf.contrib.layers.batch_norm(g3,epsilon=1e-5,scope='bn3')
g3 = tf.nn.relu(g3)
g3 = tf.image.resize_images(g3,[56,56]) # 卷积 + 激活
g_w4 = tf.get_variable('g_w4',[1,1,z_dim / 4,1],dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.02))
g_b4 = tf.get_variable('g_b4',[1],initializer=tf.truncated_normal_initializer(stddev=0.02))
g4 = tf.nn.conv2d(g3,g_w4,strides=[1,2,2,1],padding='SAME')
g4 = g4 + g_b4
g4 = tf.sigmoid(g4) # 输出g4的维度: batch_size x 28 x 28 x 1
return g4

逻辑实现如下,不同组成部分的loss值是分开计算的:

"""逻辑架构"""

tf.reset_default_graph()
batch_size = 50
z_dimensions = 100 z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions], name='z_placeholder')
x_placeholder = tf.placeholder(tf.float32, shape = [None,28,28,1], name='x_placeholder') Gz = generator(z_placeholder, batch_size, z_dimensions) # 根据z生成伪造图片
Dx = discriminator(x_placeholder) # 辨别器辨别真实图片
Dg = discriminator(Gz, reuse=True) # 辨别器辨别伪造图片 #discriminator 的loss 分为两部分
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dx, labels = tf.ones_like(Dx)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.zeros_like(Dg)))
d_loss=d_loss_real + d_loss_fake
# Generator的目标是生成尽可能真实的图像,所以计算Dg和1的loss
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.ones_like(Dg)))

优化器部分有一些注意点:

"""优化部分"""

# 由于训练时生成器和辨别器是分开训练的,
# 所以不同的训练过程对应的优化参数是要做区分的
tvars = tf.trainable_variables() d_vars = [var for var in tvars if 'd_' in var.name]
g_vars = [var for var in tvars if 'g_' in var.name] d_trainer_real = tf.train.AdamOptimizer(0.0003).minimize(d_loss_real, var_list=d_vars)
d_trainer_fake = tf.train.AdamOptimizer(0.0003).minimize(d_loss_fake, var_list=d_vars)
d_trainer = tf.train.AdamOptimizer(0.0003).minimize(d_loss, var_list=d_vars)
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

入注释所说,训练不同的位置,优化不同的参数,不可以混淆,所以这里就涉及了tf变量提取的手法,结果展示如下:

import pprint
pp = pprint.PrettyPrinter()
pp.pprint(d_vars)
pp.pprint(g_vars) [<tf.Variable 'd_w1:0' shape=(5, 5, 1, 32) dtype=float32_ref>,
<tf.Variable 'd_b1:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'd_w2:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
<tf.Variable 'd_b2:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'd_w3:0' shape=(3136, 1024) dtype=float32_ref>,
<tf.Variable 'd_b3:0' shape=(1024,) dtype=float32_ref>,
<tf.Variable 'd_w4:0' shape=(1024, 1) dtype=float32_ref>,
<tf.Variable 'd_b4:0' shape=(1,) dtype=float32_ref>] [<tf.Variable 'g_w1:0' shape=(100, 3136) dtype=float32_ref>,
<tf.Variable 'g_b1:0' shape=(3136,) dtype=float32_ref>,
<tf.Variable 'g_w2:0' shape=(3, 3, 1, 50) dtype=float32_ref>,
<tf.Variable 'g_b2:0' shape=(50,) dtype=float32_ref>,
<tf.Variable 'g_w3:0' shape=(3, 3, 50, 25) dtype=float32_ref>,
<tf.Variable 'g_b3:0' shape=(25,) dtype=float32_ref>,
<tf.Variable 'g_w4:0' shape=(1, 1, 25, 1) dtype=float32_ref>,
<tf.Variable 'g_b4:0' shape=(1,) dtype=float32_ref>]

之后是训练过程:

"""迭代训练"""

sess = tf.Session()
sess.run(tf.global_variables_initializer()) # 对discriminator的预训练
for i in range(300):
print('.',end='')
z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions])
real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1])
# 用real and fake images分别对discriminator训练
_, __, dLossReal, dLossFake = sess.run([d_trainer_real, d_trainer_fake, d_loss_real, d_loss_fake],
{x_placeholder: real_image_batch, z_placeholder: z_batch}) if (i % 100 == 0):
print("\rdLossReal:",dLossReal,"dLossFake:",dLossFake) # 交替训练 generator和discriminator
for i in range(100000):
print('.',end='')
real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1])
z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions]) # 用real and fake images同时对discriminator训练
_,dLossReal,dLossFake = sess.run([d_trainer,d_loss_real,d_loss_fake],
{x_placeholder: real_image_batch,z_placeholder: z_batch})
# 训练generator
z_batch = np.random.normal(0,1,size=[batch_size,z_dimensions])
_ = sess.run(g_trainer,feed_dict={z_placeholder: z_batch}) if i % 100 == 0:
# 每 100 iterations, 输出一个生成的图像
print("\rIteration:",i,"at",datetime.datetime.now())
z_batch = np.random.normal(0,1,size=[1,z_dimensions])
generated_images = generator(z_placeholder,1,z_dimensions, reuse=True)
images = sess.run(generated_images,{z_placeholder: z_batch})
plt.imshow(images[0].reshape([28,28]),cmap='Greys')
plt.show()
# 输出discriminator的值
im = images[0].reshape([1,28,28,1])
result = discriminator(x_placeholder, reuse=True)
estimate = sess.run(result,{x_placeholder: im})
print("Estimate:",np.squeeze(estimate))

先预训练分辨器,

然后交替训练分辨器和生成器。

其实是有一点图片可以展示的,但是我的电脑性能太渣(苏菲4),跑了600轮左右的迭代我实在于心不忍了,先搁置吧... 以后有机会回实验室在说,至少原理是体会到了。

共享变量

『TensorFlow』线程控制器类&变量作用域理解加深

之前看文档时体会不深,现在大体明白共享变量的存在意义了,它是在设计计算图时考虑的:

同一个变量如果有不同的数据流(计算图中不同的节点在不同的时刻去给同一个节点的同一个输入位置提供数据),

  • Variable变量会之间创建两个不同的变量节点去接收不同的数据流
  • get_variable变量在reuse为True时会使用同一个变量应付不同的数据流

这也就是共享变量的应用之处。这在上面的程序中体现在判别器的任务,如果接收到的是生成器生成的图像,判别器就尝试优化自己的网络结构来使自己输出0,如果接收到的是来自真实数据的图像,那么就尝试优化自己的网络结构来使自己输出1。也就是说,fake图像和real图像经过判别器的时候,要共享同一套变量,所以TensorFlow引入了变量共享机制,而和正常的卷积网络不同的是这里的fake和real变量并不在同一个计算图节点位置(real图片在x节点处输入,而fake图则在生成器输出节点位置计入计算图)。

『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上的更多相关文章

  1. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  2. 『cs231n』通过代码理解风格迁移

    『cs231n』卷积神经网络的可视化应用 文件目录 vgg16.py import os import numpy as np import tensorflow as tf from downloa ...

  3. 『cs231n』RNN之理解LSTM网络

    概述 LSTM是RNN的增强版,1.RNN能完成的工作LSTM也都能胜任且有更好的效果:2.LSTM解决了RNN梯度消失或爆炸的问题,进而可以具有比RNN更为长时的记忆能力.LSTM网络比较复杂,而恰 ...

  4. 『cs231n』计算机视觉基础

    线性分类器损失函数明细: 『cs231n』线性分类器损失函数 最优化Optimiz部分代码: 1.随机搜索 bestloss = float('inf') # 无穷大 for num in range ...

  5. 『cs231n』作业3问题2选讲_通过代码理解LSTM网络

    LSTM神经元行为分析 LSTM 公式可以描述如下: itftotgtctht=sigmoid(Wixxt+Wihht−1+bi)=sigmoid(Wfxxt+Wfhht−1+bf)=sigmoid( ...

  6. 『cs231n』作业2选讲_通过代码理解Dropout

    Dropout def dropout_forward(x, dropout_param): p, mode = dropout_param['p'], dropout_param['mode'] i ...

  7. 『cs231n』作业3问题3选讲_通过代码理解图像梯度

    Saliency Maps 这部分想探究一下 CNN 内部的原理,参考论文 Deep Inside Convolutional Networks: Visualising Image Classifi ...

  8. 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练

    一份不错的作业3资料(含答案) RNN神经元理解 单个RNN神经元行为 括号中表示的是维度 向前传播 def rnn_step_forward(x, prev_h, Wx, Wh, b): " ...

  9. 『cs231n』作业2选讲_通过代码理解优化器

    1).Adagrad一种自适应学习率算法,实现代码如下: cache += dx**2 x += - learning_rate * dx / (np.sqrt(cache) + eps) 这种方法的 ...

随机推荐

  1. P2051 [AHOI2009]中国象棋(动态规划)

    思路 好像是一道挺水的计数的,不知道为什么会是紫题 显然每行和每列最多放两个 首先考虑状压,然后发现三进制状压可做,但是三进制太麻烦了,可以拆成两个二进制,一个表示该列是否是放了一个的,一个表示该列是 ...

  2. Docker Engine SDKs and API 的开发2

    Examples using the Docker Engine SDKs and Docker API After you install Docker, you can install the G ...

  3. Python-ConfigParser获取配置项名称大小写问题

    C:\Python27\Lib\ConfigParser.py: def optionxform(self, optionstr): return optionstr.lower() 会将配置文件中的 ...

  4. 51nod 1624 取余最长路

    http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1624 题意: 思路:因为一共只有3行,所以只需要确定第一行和第二行的转折 ...

  5. 【Mysql】key 、primary key 、unique key 与index区别

    参考:https://blog.csdn.net/nanamasuda/article/details/52543177 总的来说,primary key .unique key 这些key建立的同时 ...

  6. 【Django】【Shell】

    django-admin startproject guest python manage.py startapp sign python manage.py runserver 127.0.0.1: ...

  7. Codeforces 767E Change-free

    题目链接:http://codeforces.com/contest/767/problem/E 居然是一个瞎几把贪心(E比B水系列) 考虑要每一次操作至少要用${\left \lfloor \fra ...

  8. go 接口以及对象传递

    // Sample program to show how to use an interface in Go. package main import ( "fmt" ) // ...

  9. MySQL基本使用

    来自李兴华视频. 1. 启动命令行方式 2. 连接mysql数据库,其中“-u”标记的是输入用户名,“-p”标记的是输入密码. 3. 建立一个新数据库——mldn,使用UTF-8编码: create ...

  10. 关于Oracle 12C pdb用户无法登录的问题

    新装了oracle12c,对新的CDB和PDB用户如何登录一直一头雾水,经过一晚上的查找,终于解决. sqlplus /nolog -> conn /as sysdba 登录到oracle 将s ...