GAN网络的整体公式:

公式各参数介绍如下:

X是真实地图片,而对应的标签是1。

G(Z)是通过给定的噪声Z,生成图片(实际上是通过给定的Z生成一个tensor),对应的标签是0。

D是一个二分类网络,对于给定的图片判别真假。

D和G的参数更新方式:

D通过输入的真假图片,通过BCE(二分类交叉熵)更新自己的参数。

D对G(Z)生成的标签L,G尽可能使L为true,也就是1,通过BCE(二分类交叉熵)更新自己的参数。

公式演变:

对于G来说要使D无法判别自己生成的图片是假的,故而要使G(Z)越大越好,所以就使得V(G,D)越小越好;而对于D,使G(Z)越小D(X)越大,故而使V(G,D)越大越好

为了便于求导,故而加了log,变为如下:

最后对整个batch求期望,变为如下:

基于mnist实现的GAN网络结构对应的代码

import itertools
import math
import time import torch
import torchvision
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from IPython import display
from torch.autograd import Variable
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]) train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True) class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(784, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
) def forward(self, x):
out = self.model(x.view(x.size(0), 784))
out = out.view(out.size(0), -1)
return out class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 784),
nn.Tanh()
) def forward(self, x):
x = x.view(x.size(0), -1)
out = self.model(x)
return out discriminator = Discriminator().cuda()
generator = Generator().cuda()
criterion = nn.BCELoss()
lr = 0.0002
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr) def train_discriminator(discriminator, images, real_labels, fake_images, fake_labels):
discriminator.zero_grad()
outputs = discriminator(images)
real_loss = criterion(outputs, real_labels)
real_score = outputs outputs = discriminator(fake_images)
fake_loss = criterion(outputs, fake_labels)
fake_score = outputs d_loss = real_loss + fake_loss
d_loss.backward()
d_optimizer.step()
return d_loss, real_score, fake_score
def train_generator(generator, discriminator_outputs, real_labels):
generator.zero_grad()
g_loss = criterion(discriminator_outputs, real_labels)
g_loss.backward()
g_optimizer.step()
return g_loss # draw samples from the input distribution to inspect the generation on training
num_test_samples = 16
test_noise = Variable(torch.randn(num_test_samples, 100).cuda())
# create figure for plotting
size_figure_grid = int(math.sqrt(num_test_samples))
fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6))
for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
ax[i, j].get_xaxis().set_visible(False)
ax[i, j].get_yaxis().set_visible(False) # set number of epochs and initialize figure counter
num_epochs = 200
num_batches = len(train_loader)
num_fig = 0 for epoch in range(num_epochs):
for n, (images, _) in enumerate(train_loader):
images = Variable(images.cuda())
real_labels = Variable(torch.ones(images.size(0)).cuda()) # Sample from generator
noise = Variable(torch.randn(images.size(0), 100).cuda())
fake_images = generator(noise)
fake_labels = Variable(torch.zeros(images.size(0)).cuda()) # Train the discriminator
d_loss, real_score, fake_score = train_discriminator(discriminator, images, real_labels, fake_images,
fake_labels) # Sample again from the generator and get output from discriminator
noise = Variable(torch.randn(images.size(0), 100).cuda())
fake_images = generator(noise)
outputs = discriminator(fake_images) # Train the generator
g_loss = train_generator(generator, outputs, real_labels) if (n + 1) % 100 == 0:
test_images = generator(test_noise) for k in range(num_test_samples):
i = k // 4
j = k % 4
ax[i, j].cla()
ax[i, j].imshow(test_images[k, :].data.cpu().numpy().reshape(28, 28), cmap='Greys')
display.clear_output(wait=True)
display.display(plt.gcf()) plt.savefig('results/mnist-gan-%03d.png' % num_fig)
num_fig += 1
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
'D(x): %.2f, D(G(z)): %.2f'
% (epoch + 1, num_epochs, n + 1, num_batches, d_loss.data[0], g_loss.data[0],
real_score.data.mean(), fake_score.data.mean())) fig.close()

GAN网络原理介绍和代码的更多相关文章

  1. When I see you again(加密原理介绍,代码实现DES、AES、RSA、Base64、MD5)

    关于网络安全的数据加密部分,本来打算总结一篇博客搞定,没想到东西太多,这已是第三篇了,而且这篇写了多次,熬了多次夜,真是again and again.起个名字:数据加密三部曲,前两部链接如下: 整体 ...

  2. 加密原理介绍,代码实现DES、AES、RSA、Base64、MD5

    阅读目录 github下载地址 一.DES对称加密 二.AES对称加密 三.RSA非对称加密 四.实际使用 五.关于Padding 关于电脑终端Openssl加密解密命令 关于网络安全的数据加密部分, ...

  3. TF实战:(Mask R-CNN原理介绍与代码实现)-Chapter-8

    二值掩膜输出依据种类预测分支(Faster R-CNN部分)预测结果:当前RoI的物体种类为i第i个二值掩膜输出就是该RoI的损失Lmask 对于预测的二值掩膜输出,我们对每个像素点应用sigmoid ...

  4. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  5. 常见的GAN网络的相关原理及推导

    常见的GAN网络的相关原理及推导 在上一篇中我们给大家介绍了GAN的相关原理和推导,GAN是VAE的后一半,再加上一个鉴别网络.这样而导致了完全不同的训练方式. GAN,生成对抗网络,主要有两部分构成 ...

  6. GAN网络从入门教程(一)之GAN网络介绍

    GAN网络从入门教程(一)之GAN网络介绍 稍微的开一个新坑,同样也是入门教程(因此教程的内容不会是从入门到精通,而是从入门到入土).主要是为了完成数据挖掘的课程设计,然后就把挖掘榔头挖到了GAN网络 ...

  7. GAN网络从入门教程(二)之GAN原理

    在一篇博客GAN网络从入门教程(一)之GAN网络介绍中,简单的对GAN网络进行了一些介绍,介绍了其是什么,然后大概的流程是什么. 在这篇博客中,主要是介绍其数学公式,以及其算法流程.当然数学公式只是简 ...

  8. 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上

    GAN网络架构分析 上图即为GAN的逻辑架构,其中的noise vector就是特征向量z,real images就是输入变量x,标签的标准比较简单(二分类么),real的就是tf.ones,fake ...

  9. UIContainerView纯代码实现及原理介绍

    UIContainerView纯代码实现及原理介绍 1.1-在StoryBoard中使用UIContainerView 1.2-纯代码使用UIContainerView 1.3-UIContainer ...

随机推荐

  1. 精通awk系列(7):awk读取行的细节

    回到: Linux系列文章 Shell系列文章 Awk系列文章 详细分析awk如何读取文件 awk读取输入文件时,每次读取一条记录(record)(默认情况下按行读取,所以此时记录就是行).每读取一条 ...

  2. 某酒店建筑设计CAD施工图

    本素材为某酒店建筑设计CAD施工图,其中包涵的有酒店室内装修图纸.各个标间房屋改造图以及酒店场外建筑施工图.其中图纸的格式都是为dwg格式的.想要查看图纸就可以使用CAD看图软件来进行查看.以下就是一 ...

  3. React 基础笔记

    概览 React 是一个声明式,高效且灵活的用于构建用户界面的 JavaScript库.可以将一些简短.独立的代码片段组合成复杂的UI界面,这些片段被称为"组件". React 大 ...

  4. Git 在同一台机器上配置多个Git帐号

    在同一台机器上配置多个Git帐号 By:授客 QQ:1033553122 实践环境 win10 Git-2.21.0-64-bit.exe TortoiseGit-2.8.0.0-64bit.msi ...

  5. elasticsearch对无意义的词进行屏蔽——停用词

    介绍 在使用elasticsearch进行搜索业务的时候,发现一篇和搜索关键字完全不匹配的文章排在最前面.打开它发现原来是这篇文章含有非常多的"的"这个无意义的词.而我的搜索关键字 ...

  6. VC遍历访问目录下的文件

    访问目录文件夹下的文件是经常需要的操作,C/C++和win32接口都没有提供直接调用的函数.在这里总结了几个经常用到的函数,通过MFC的CFileFind函数递归遍历实现,包括以下几个功能函数: 查找 ...

  7. Kali Linux configuration "Ettercap"

    Xx_Instroduction Ettercap is a man-in-the-middle attack(MITM) tool,kali take this tool,so,use front ...

  8. Docker运行dotnetcore

                    windows下安装docker 参考: https://www.jianshu.com/p/502b4ac536ef https://docs.docker.com/ ...

  9. c++ 的namespace及注意事项

    前文 下文中的出现的"当前域"为"当前作用域"的简写 namepsace在c++中是用来避免不同模块下相同名字冲突的一种关键字,本文粗略的介绍了一下namesp ...

  10. 划分为k个相等的子集

    给定一个整数数组  nums 和一个正整数 k,找出是否有可能把这个数组分成 k 个非空子集,其总和都相等. 示例 1: 输入: nums = [4, 3, 2, 3, 5, 2, 1], k = 4 ...