GAN网络原理介绍和代码
GAN网络的整体公式:

公式各参数介绍如下:
X是真实地图片,而对应的标签是1。
G(Z)是通过给定的噪声Z,生成图片(实际上是通过给定的Z生成一个tensor),对应的标签是0。
D是一个二分类网络,对于给定的图片判别真假。
D和G的参数更新方式:
D通过输入的真假图片,通过BCE(二分类交叉熵)更新自己的参数。
D对G(Z)生成的标签L,G尽可能使L为true,也就是1,通过BCE(二分类交叉熵)更新自己的参数。
公式演变:
对于G来说要使D无法判别自己生成的图片是假的,故而要使G(Z)越大越好,所以就使得V(G,D)越小越好;而对于D,使G(Z)越小D(X)越大,故而使V(G,D)越大越好

为了便于求导,故而加了log,变为如下:

最后对整个batch求期望,变为如下:

基于mnist实现的GAN网络结构对应的代码
import itertools
import math
import time import torch
import torchvision
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from IPython import display
from torch.autograd import Variable
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]) train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True) class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(784, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
) def forward(self, x):
out = self.model(x.view(x.size(0), 784))
out = out.view(out.size(0), -1)
return out class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 784),
nn.Tanh()
) def forward(self, x):
x = x.view(x.size(0), -1)
out = self.model(x)
return out discriminator = Discriminator().cuda()
generator = Generator().cuda()
criterion = nn.BCELoss()
lr = 0.0002
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr) def train_discriminator(discriminator, images, real_labels, fake_images, fake_labels):
discriminator.zero_grad()
outputs = discriminator(images)
real_loss = criterion(outputs, real_labels)
real_score = outputs outputs = discriminator(fake_images)
fake_loss = criterion(outputs, fake_labels)
fake_score = outputs d_loss = real_loss + fake_loss
d_loss.backward()
d_optimizer.step()
return d_loss, real_score, fake_score
def train_generator(generator, discriminator_outputs, real_labels):
generator.zero_grad()
g_loss = criterion(discriminator_outputs, real_labels)
g_loss.backward()
g_optimizer.step()
return g_loss # draw samples from the input distribution to inspect the generation on training
num_test_samples = 16
test_noise = Variable(torch.randn(num_test_samples, 100).cuda())
# create figure for plotting
size_figure_grid = int(math.sqrt(num_test_samples))
fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6))
for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
ax[i, j].get_xaxis().set_visible(False)
ax[i, j].get_yaxis().set_visible(False) # set number of epochs and initialize figure counter
num_epochs = 200
num_batches = len(train_loader)
num_fig = 0 for epoch in range(num_epochs):
for n, (images, _) in enumerate(train_loader):
images = Variable(images.cuda())
real_labels = Variable(torch.ones(images.size(0)).cuda()) # Sample from generator
noise = Variable(torch.randn(images.size(0), 100).cuda())
fake_images = generator(noise)
fake_labels = Variable(torch.zeros(images.size(0)).cuda()) # Train the discriminator
d_loss, real_score, fake_score = train_discriminator(discriminator, images, real_labels, fake_images,
fake_labels) # Sample again from the generator and get output from discriminator
noise = Variable(torch.randn(images.size(0), 100).cuda())
fake_images = generator(noise)
outputs = discriminator(fake_images) # Train the generator
g_loss = train_generator(generator, outputs, real_labels) if (n + 1) % 100 == 0:
test_images = generator(test_noise) for k in range(num_test_samples):
i = k // 4
j = k % 4
ax[i, j].cla()
ax[i, j].imshow(test_images[k, :].data.cpu().numpy().reshape(28, 28), cmap='Greys')
display.clear_output(wait=True)
display.display(plt.gcf()) plt.savefig('results/mnist-gan-%03d.png' % num_fig)
num_fig += 1
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
'D(x): %.2f, D(G(z)): %.2f'
% (epoch + 1, num_epochs, n + 1, num_batches, d_loss.data[0], g_loss.data[0],
real_score.data.mean(), fake_score.data.mean())) fig.close()
GAN网络原理介绍和代码的更多相关文章
- When I see you again(加密原理介绍,代码实现DES、AES、RSA、Base64、MD5)
关于网络安全的数据加密部分,本来打算总结一篇博客搞定,没想到东西太多,这已是第三篇了,而且这篇写了多次,熬了多次夜,真是again and again.起个名字:数据加密三部曲,前两部链接如下: 整体 ...
- 加密原理介绍,代码实现DES、AES、RSA、Base64、MD5
阅读目录 github下载地址 一.DES对称加密 二.AES对称加密 三.RSA非对称加密 四.实际使用 五.关于Padding 关于电脑终端Openssl加密解密命令 关于网络安全的数据加密部分, ...
- TF实战:(Mask R-CNN原理介绍与代码实现)-Chapter-8
二值掩膜输出依据种类预测分支(Faster R-CNN部分)预测结果:当前RoI的物体种类为i第i个二值掩膜输出就是该RoI的损失Lmask 对于预测的二值掩膜输出,我们对每个像素点应用sigmoid ...
- 『TensorFlow』通过代码理解gan网络_中
『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...
- 常见的GAN网络的相关原理及推导
常见的GAN网络的相关原理及推导 在上一篇中我们给大家介绍了GAN的相关原理和推导,GAN是VAE的后一半,再加上一个鉴别网络.这样而导致了完全不同的训练方式. GAN,生成对抗网络,主要有两部分构成 ...
- GAN网络从入门教程(一)之GAN网络介绍
GAN网络从入门教程(一)之GAN网络介绍 稍微的开一个新坑,同样也是入门教程(因此教程的内容不会是从入门到精通,而是从入门到入土).主要是为了完成数据挖掘的课程设计,然后就把挖掘榔头挖到了GAN网络 ...
- GAN网络从入门教程(二)之GAN原理
在一篇博客GAN网络从入门教程(一)之GAN网络介绍中,简单的对GAN网络进行了一些介绍,介绍了其是什么,然后大概的流程是什么. 在这篇博客中,主要是介绍其数学公式,以及其算法流程.当然数学公式只是简 ...
- 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上
GAN网络架构分析 上图即为GAN的逻辑架构,其中的noise vector就是特征向量z,real images就是输入变量x,标签的标准比较简单(二分类么),real的就是tf.ones,fake ...
- UIContainerView纯代码实现及原理介绍
UIContainerView纯代码实现及原理介绍 1.1-在StoryBoard中使用UIContainerView 1.2-纯代码使用UIContainerView 1.3-UIContainer ...
随机推荐
- 转载 SAP用户权限控制设置及开发
创建用户SU01 事务码:SU01,用户主数据的维护,可以创建.修改.删除.锁定.解锁.修改密码等 缺省:可以设置用户的起始菜单.登录的默认语言.数字显示格式.以及日期和时间的格式设置 参数:SAP很 ...
- jquery 常用选择器基础语法学习
siblings方法的常用应用场景:选中高亮 实现代码 <!DOCTYPE html> <html> <head> <meta charset="U ...
- Android8.1 SystemUI源码分析之 Notification流程
代码流程 1.先看UI显示,StatuBar加载 CollapsedStatusBarFragment 替换 status_bar_container(状态栏通知显示区域) SystemUI\src\ ...
- 2019 DevOps 必备面试题——容器化和虚拟化
原文地址:https://medium.com/edureka/devops-interview-questions-e91a4e6ecbf3 原文作者:Saurabh Kulshrestha 翻译君 ...
- Mysql 事务及其原理
Mysql 事务及其原理 什么是事务 什么是事务?事务是作为单个逻辑工作单元执行的一系列操作,通俗易懂的说就是一组原子性的 SQL 查询.Mysql 中事务的支持在存储引擎层,MyISAM 存储引擎不 ...
- gitlab忘记密码如何重置
gitlab web登入密码忘记以后可以用如下方式修改密码shell>cd /home/git/gitlabshell> su gitshell>bundle exec rails ...
- cobbler无人值守
一.背景介绍 作为运维,在公司经常遇到一些机械性重复工作要做,例如:为新机器装系统,一台两台机器装系统,可以用光盘.U盘等介质安装,1小时也完成了,但是如果有成百台的服务器还要用光盘.U盘去安装, ...
- centos查询目标文件文件所在位置
之前有试过whereis这种语法但是查询文件不大理想.然后找到了下边这种方式可以很好的查询目标文件的位置 #在根目录 /下查找所有叫nginx的文件 find / -name nginx
- C#的语法----程序结构(1)
接下来的内容是整个C#学习的脉络,它将各个知识点串联了起来,是整个C#的重点,所以篇幅较长. 首先,我们类比一下PLC和C#执行代码的方式,其实不难发现都是顺序扫描,以Main为程序入口,从上到下一行 ...
- MyBatis核心对象之StatementHandler
MyBatis核心对象之StatementHandler StatementHandler ResultHandler ParameterHandler Executor org.apache.iba ...