代码实现

当初学习时,主要学习的这个博客 https://xyang35.github.io/2017/08/22/GAN-1/ ,写的挺好的。

本文目的,用GAN实现最简单的例子,帮助认识GAN算法。

import numpy as np
from matplotlib import pyplot as plt
batch_size = 4

2. 真实数据集,我们要通过GAN学习这个数据集,然后生成和他分布规则一样的数据集

X = np.random.normal(size=(1000, 2))
A = np.array([[1, 2], [-0.1, 0.5]])
b = np.array([1, 2])
X = np.dot(X, A) + b plt.scatter(X[:, 0], X[:, 1])
plt.show() # 等会通过这个函数,不断从中取x值,取值数量为batch_size
def iterate_minibatch(x, batch_size, shuffle=True):
indices = np.arange(x.shape[0])
if shuffle:
np.random.shuffle(indices) for i in range(0, x.shape[0], batch_size):
yield x[indices[i:i + batch_size], :]

3.封装GAN对象

包含生成器,判别器

class GAN(object):
def __init__(self):
#初始函数,在这里对初始化模型
def netG(self, z):
#生成器模型
def netD(self, x, reuse=False):
#判别器模型

4.生成器netG

随意输入的z,通过z*w+b的矩阵运算(全连接运算),返回结果

    def netG(self, z):
"""1-layer fully connected network""" with tf.variable_scope("generator") as scope:
W = tf.get_variable(name="g_W", shape=[2, 2],
initializer=tf.contrib.layers.xavier_initializer(),
trainable=True)
b = tf.get_variable(name="g_b", shape=[2],
initializer=tf.zeros_initializer(),
trainable=True)
return tf.matmul(z, W) + b

5.判别器nefD

判别器为三层全连接网络。隐层部分使用tanh激活函数。输出部分没有激活函数

    def netD(self, x, reuse=False):
"""3-layer fully connected network""" with tf.variable_scope("discriminator") as scope:
if reuse:
scope.reuse_variables() W1 = tf.get_variable(name="d_W1", shape=[2, 5],
initializer=tf.contrib.layers.xavier_initializer(),
trainable=True)
b1 = tf.get_variable(name="d_b1", shape=[5],
initializer=tf.zeros_initializer(),
trainable=True)
W2 = tf.get_variable(name="d_W2", shape=[5, 3],
initializer=tf.contrib.layers.xavier_initializer(),
trainable=True)
b2 = tf.get_variable(name="d_b2", shape=[3],
initializer=tf.zeros_initializer(),
trainable=True)
W3 = tf.get_variable(name="d_W3", shape=[3, 1],
initializer=tf.contrib.layers.xavier_initializer(),
trainable=True)
b3 = tf.get_variable(name="d_b3", shape=[1],
initializer=tf.zeros_initializer(),
trainable=True) layer1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
layer2 = tf.nn.tanh(tf.matmul(layer1, W2) + b2)
return tf.matmul(layer2, W3) + b3

6.初始化__init__函数

def __init__(self):
# input, output
#占位变量,等会用来保存随机产生的数,
self.z = tf.placeholder(tf.float32, shape=[None, 2], name='z')
#占位变量,真实数据的
self.x = tf.placeholder(tf.float32, shape=[None, 2], name='real_x') # define the network
#生成器,对随机变量进行加工,产生伪造的数据
self.fake_x = self.netG(self.z) #判别器对真实数据进行判别,返回判别结果
#reuse=false,表示不是共享变量,需要tensorflow开辟变量地址
self.real_logits = self.netD(self.x, reuse=False) #判别器对伪造数据进行判别,返回判别结果
#reuse=true,表示是共享变量,复用netD中已有的变量
self.fake_logits = self.netD(self.fake_x, reuse=True) # define losses
#判定器的损失值,将真实数据的判定为真实数据,将伪造数据的判断为伪造数据的得分情况
self.loss_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_logits,
labels=tf.ones_like(self.real_logits))) + \
tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
labels=tf.zeros_like(self.real_logits)))
#生成器的生成分数。伪造的数据,别判断器判定为真实数据的得分情况
self.loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
labels=tf.ones_like(self.real_logits))) # collect variables
t_vars = tf.trainable_variables()
#存放判别器中用到的变量
self.d_vars = [var for var in t_vars if 'd_' in var.name]
#存放生成器中用到的变量
self.g_vars = [var for var in t_vars if 'g_' in var.name]

7.开始训练

gan = GAN()

#使用随机梯度下降
d_optim = tf.train.AdamOptimizer(learning_rate=0.05).minimize(gan.loss_D, var_list=gan.d_vars)
g_optim = tf.train.AdamOptimizer(learning_rate=0.01).minimize(gan.loss_G, var_list=gan.g_vars) init = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
#将数据循环10次
for epoch in range(10):
avg_loss = 0.
count = 0
#从真实数据集当中,随机抓取batch_size数量个值
for x_batch in iterate_minibatch(X, batch_size=batch_size):
# generate noise z
#随机变量,数量为batch_size
z_batch = np.random.normal(size=(4, 2)) # update D network
#将拿到的真实数据值和随机生成的数值,喂养给sess,并bp优化一次
loss_D, _ = sess.run([gan.loss_D, d_optim],
feed_dict={
gan.z: z_batch,
gan.x: x_batch,
}) # update G network
loss_G, _ = sess.run([gan.loss_G, g_optim],
feed_dict={
gan.z: z_batch,
gan.x: np.zeros(z_batch.shape), # dummy input
}) avg_loss += loss_D
count += 1 avg_loss /= count
#每一个epoch都展示一次生成效果
z = np.random.normal(size=(100, 2))
# 随机生成100个数值,0到1000---用来从真实值里面取数据
excerpt = np.random.randint(1000, size=1000)
fake_x, real_logits, fake_logits = sess.run([gan.fake_x, gan.real_logits, gan.fake_logits],
feed_dict={gan.z: z, gan.x: X[excerpt, :]})
accuracy = 0.5 * (np.sum(real_logits > 0.5) / 100. + np.sum(fake_logits < 0.5) / 100.)
print('\ndiscriminator loss at epoch %d: %f' % (epoch, avg_loss))
print('\ndiscriminator accuracy at epoch %d: %f' % (epoch, accuracy))
plt.scatter(X[:, 0], X[:, 1])
plt.scatter(fake_x[:, 0], fake_x[:, 1])
plt.show()

效果

完整代码下载

欢迎转载,转载请注明出处。欢迎沟通交流: panfengqqs@qq.com)

GAN生成式对抗网络(二)——tensorflow代码示例的更多相关文章

  1. GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构

    论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...

  2. GAN生成式对抗网络(三)——mnist数据生成

    通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...

  3. GAN生成式对抗网络(一)——原理

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...

  4. 不要怂,就是GAN (生成式对抗网络) (一)

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  5. 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  6. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

  7. 不要怂,就是GAN (生成式对抗网络) (五):无约束条件的 GAN 代码与网络的 Graph

    GAN 这个领域发展太快,日新月异,各种 GAN 层出不穷,前几天看到一篇关于 Wasserstein GAN 的文章,讲的很好,在此把它分享出来一起学习:https://zhuanlan.zhihu ...

  8. 不要怂,就是GAN (生成式对抗网络) (二)

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  9. 不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

随机推荐

  1. 纯css实现移动端横向滑动列表&&overflow:atuo;隐藏滚动条

    <!DOCTYPE html> <html> <head> <title>横向滑动</title> <style type=" ...

  2. 接口请求 URL转码

    什么是URL转码 不管是以何种方式传递url时,如果要传递的url中包含特殊字符,如想要传递一个+,但是这个+会被url会被编码成空格,想要传递&,被url处理成分隔符. 尤其是当传递的url ...

  3. Tomcat安装及其目录结构介绍

    Tomcat服务器是一个免费的开放源代码的Web应用服务器,属于轻量级应用服务器,在中小型系统和并发访问用户不是很多的场合下被普遍使用,是开发和调试JSP程序的首选. Tomcat的安装版本有绿色解压 ...

  4. .NET CORE 下 MariaDB DBfirst 生成model层 并配置连接参数

    1.首先新建一个类库,然后通过NuGet安装下面三个包 2.然后在程序包管理器控制台中运行以下代码(ps:记得默认项目选择刚才新建的项目,同时设置为启动项) server 是服务器地址 databas ...

  5. ES6 Proise 简单理解

    Promise 这是ES6中增加的一个处理异步的对象. 传统变成写异步函数的时候,经常会遇到回调套回调: Promise 是异步编程的一种解决方案,比传统的解决方案 -----回调函数和事件----- ...

  6. # marshalsec使用

    开启rmi服务,恶意类放到服务上 D:\jdk_1.8\bin\java.exe -cp marshalsec-0.0.3-SNAPSHOT-all.jar marshalsec.jndi.RMIRe ...

  7. 在浏览器输入 URL 回车之后发生了什么

    注意:本文的步骤是建立在,请求的是一个简单的 HTTP 请求,没有 HTTPS.HTTP2.最简单的 DNS.没有代理.并且服务器没有任何问题的基础上. 大致流程 URL 解析 DNS 查询 TCP ...

  8. redis集群1

    redis-trib.rb命令详解   redis-trib.rb是官方提供的Redis Cluster的管理工具,无需额外下载,默认位于源码包的src目录下,但因该工具是用ruby开发的,所以需要准 ...

  9. python连接impala时,执行SQL报错expecting list of size 2 for struct args

    这个错误困扰了好久,因为集群有多台,暂放到其他几台机器上执行了SQL操作,一直在找解决方法,无意间得到真传,喜出望外啊 报错信息: Traceback (most recent call last): ...

  10. rabbit MQ 的环境及命令使用(一)

    RabbitMQ依赖erlang,所以先安装erlang,然后再安装RabbitMQ; 先安装erlang,双击erlang的安装文件即可,然后配置环境变量: ERLANG_HOME=D:\Progr ...