GAN网络的整体公式:

公式各参数介绍如下:

X是真实地图片,而对应的标签是1。

G(Z)是通过给定的噪声Z,生成图片(实际上是通过给定的Z生成一个tensor),对应的标签是0。

D是一个二分类网络,对于给定的图片判别真假。

D和G的参数更新方式:

D通过输入的真假图片,通过BCE(二分类交叉熵)更新自己的参数。

D对G(Z)生成的标签L,G尽可能使L为true,也就是1,通过BCE(二分类交叉熵)更新自己的参数。

公式演变:

对于G来说要使D无法判别自己生成的图片是假的,故而要使G(Z)越大越好,所以就使得V(G,D)越小越好;而对于D,使G(Z)越小D(X)越大,故而使V(G,D)越大越好

为了便于求导,故而加了log,变为如下:

最后对整个batch求期望,变为如下:

基于mnist实现的GAN网络结构对应的代码

import itertools
import math
import time import torch
import torchvision
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from IPython import display
from torch.autograd import Variable
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]) train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True) class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(784, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
) def forward(self, x):
out = self.model(x.view(x.size(0), 784))
out = out.view(out.size(0), -1)
return out class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 784),
nn.Tanh()
) def forward(self, x):
x = x.view(x.size(0), -1)
out = self.model(x)
return out discriminator = Discriminator().cuda()
generator = Generator().cuda()
criterion = nn.BCELoss()
lr = 0.0002
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr) def train_discriminator(discriminator, images, real_labels, fake_images, fake_labels):
discriminator.zero_grad()
outputs = discriminator(images)
real_loss = criterion(outputs, real_labels)
real_score = outputs outputs = discriminator(fake_images)
fake_loss = criterion(outputs, fake_labels)
fake_score = outputs d_loss = real_loss + fake_loss
d_loss.backward()
d_optimizer.step()
return d_loss, real_score, fake_score
def train_generator(generator, discriminator_outputs, real_labels):
generator.zero_grad()
g_loss = criterion(discriminator_outputs, real_labels)
g_loss.backward()
g_optimizer.step()
return g_loss # draw samples from the input distribution to inspect the generation on training
num_test_samples = 16
test_noise = Variable(torch.randn(num_test_samples, 100).cuda())
# create figure for plotting
size_figure_grid = int(math.sqrt(num_test_samples))
fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6))
for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
ax[i, j].get_xaxis().set_visible(False)
ax[i, j].get_yaxis().set_visible(False) # set number of epochs and initialize figure counter
num_epochs = 200
num_batches = len(train_loader)
num_fig = 0 for epoch in range(num_epochs):
for n, (images, _) in enumerate(train_loader):
images = Variable(images.cuda())
real_labels = Variable(torch.ones(images.size(0)).cuda()) # Sample from generator
noise = Variable(torch.randn(images.size(0), 100).cuda())
fake_images = generator(noise)
fake_labels = Variable(torch.zeros(images.size(0)).cuda()) # Train the discriminator
d_loss, real_score, fake_score = train_discriminator(discriminator, images, real_labels, fake_images,
fake_labels) # Sample again from the generator and get output from discriminator
noise = Variable(torch.randn(images.size(0), 100).cuda())
fake_images = generator(noise)
outputs = discriminator(fake_images) # Train the generator
g_loss = train_generator(generator, outputs, real_labels) if (n + 1) % 100 == 0:
test_images = generator(test_noise) for k in range(num_test_samples):
i = k // 4
j = k % 4
ax[i, j].cla()
ax[i, j].imshow(test_images[k, :].data.cpu().numpy().reshape(28, 28), cmap='Greys')
display.clear_output(wait=True)
display.display(plt.gcf()) plt.savefig('results/mnist-gan-%03d.png' % num_fig)
num_fig += 1
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
'D(x): %.2f, D(G(z)): %.2f'
% (epoch + 1, num_epochs, n + 1, num_batches, d_loss.data[0], g_loss.data[0],
real_score.data.mean(), fake_score.data.mean())) fig.close()

GAN网络原理介绍和代码的更多相关文章

  1. When I see you again(加密原理介绍,代码实现DES、AES、RSA、Base64、MD5)

    关于网络安全的数据加密部分,本来打算总结一篇博客搞定,没想到东西太多,这已是第三篇了,而且这篇写了多次,熬了多次夜,真是again and again.起个名字:数据加密三部曲,前两部链接如下: 整体 ...

  2. 加密原理介绍,代码实现DES、AES、RSA、Base64、MD5

    阅读目录 github下载地址 一.DES对称加密 二.AES对称加密 三.RSA非对称加密 四.实际使用 五.关于Padding 关于电脑终端Openssl加密解密命令 关于网络安全的数据加密部分, ...

  3. TF实战:(Mask R-CNN原理介绍与代码实现)-Chapter-8

    二值掩膜输出依据种类预测分支(Faster R-CNN部分)预测结果:当前RoI的物体种类为i第i个二值掩膜输出就是该RoI的损失Lmask 对于预测的二值掩膜输出,我们对每个像素点应用sigmoid ...

  4. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  5. 常见的GAN网络的相关原理及推导

    常见的GAN网络的相关原理及推导 在上一篇中我们给大家介绍了GAN的相关原理和推导,GAN是VAE的后一半,再加上一个鉴别网络.这样而导致了完全不同的训练方式. GAN,生成对抗网络,主要有两部分构成 ...

  6. GAN网络从入门教程(一)之GAN网络介绍

    GAN网络从入门教程(一)之GAN网络介绍 稍微的开一个新坑,同样也是入门教程(因此教程的内容不会是从入门到精通,而是从入门到入土).主要是为了完成数据挖掘的课程设计,然后就把挖掘榔头挖到了GAN网络 ...

  7. GAN网络从入门教程(二)之GAN原理

    在一篇博客GAN网络从入门教程(一)之GAN网络介绍中,简单的对GAN网络进行了一些介绍,介绍了其是什么,然后大概的流程是什么. 在这篇博客中,主要是介绍其数学公式,以及其算法流程.当然数学公式只是简 ...

  8. 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上

    GAN网络架构分析 上图即为GAN的逻辑架构,其中的noise vector就是特征向量z,real images就是输入变量x,标签的标准比较简单(二分类么),real的就是tf.ones,fake ...

  9. UIContainerView纯代码实现及原理介绍

    UIContainerView纯代码实现及原理介绍 1.1-在StoryBoard中使用UIContainerView 1.2-纯代码使用UIContainerView 1.3-UIContainer ...

随机推荐

  1. NSwag.AspNetCore常用功能介绍

    对于asp.net core 下的Swagger,之前一直用Swashbuckle的,因为官方推荐,再加上有老张的博客助力<从壹开始前后端分离[ .NET Core2.0/3.0 +Vue2.0 ...

  2. QGIS练手 - 数据

    又熬夜了... 这篇博客可能会将QGIS数据管理部分和ArcGIS数据管理进行对比学习. 1. 本地数据文件与数据库(矢量) 1.1 文件 QGIS用的是shp文件.kml文件.geojson文件较多 ...

  3. Android8.1 SystemUI源码分析之 电池时钟刷新

    SystemUI源码分析相关文章 Android8.1 SystemUI源码分析之 Notification流程 分析之前再贴一下 StatusBar 相关类图 电池图标刷新 从上篇的分析得到电池图标 ...

  4. Tornado—添加请求头允许跨域请求访问

    跨域请求访问 如果是前后端分离,那就肯定会遇到cros跨域请求难题,可以设置一个BaseHandler,然后继承即可. class BaseHandler(tornado.web.RequestHan ...

  5. docker工具之基本命令

    docker工具之基本命令 1.docker服务的启动.停止.重启 systemctl start docker #启动docker服务 systemctl daemon-reload #守护进程重启 ...

  6. 11. Go 语言网络编程

    Go 语言网络编程 Go语言在编写 web 应用方面非常得力.因为目前它还没有 GUI(Graphic User Interface 图形化用户界面)的框架,通过文本或者模板展现的 html 界面是目 ...

  7. 剑指Offer-39.把数组排成最小的数(C++/Java)

    题目: 输入一个正整数数组,把数组里所有数字拼接起来排成一个数,打印能拼接出的所有数字中最小的一个.例如输入数组{3,32,321},则打印出这三个数字能排成的最小数字为321323. 分析: 将数组 ...

  8. Swoole如何处理高并发

    有需要学习交流的友人请加入swoole交流群的咱们一起,有问题一起交流,一起进步!前提是你是学技术的.感谢阅读! 点此加入该群 swoole如何处理高并发 ①Reactor模型介绍 IO复用异步非阻塞 ...

  9. 性能分析-java程序篇之案例-工具和方法

    1. 背景说明 线上服务响应时间超过40秒,登录服务器发现cpu将近100%了(如下图),针对此问题,本文说明排查过程.工具以定位具体的原因. 2. 分析排查过程 此类问题的排查,有两款神器可用,分别 ...

  10. Java基础(七)

    左连接,右连接,内连接,全连接的区别 左连接:返回左表所有行,右表没有匹配行则返回null 右连接:返回右表所有行,左表没有匹配行则返回null 内连接:返回左右表共有行 全连接:返回左右表所有行,无 ...