Pytorch 实现 GAN 网络

原理

GAN的基本原理其实非常简单,假设我们有两个网络,G(Generator)和D(Discriminator)。它们的功能分别是:

G 是一个生成网络,它接收一个随机的噪声z,通过这个噪声生成伪造数据,记做 G(z)。

D 是一个判别网络,判别数据是不是“真实的”。它的输入参数是x,输出记为 D(x) 代表 x 为真实的概率。如果为 1 就代表 x 为真的概率是100%,而输出为 0 代表为真概率是0% 即为假。

在训练过程中,生成网络 G 的目标就是尽量生成真实的数据去欺骗判别网络D。而 D 的目标就是尽量把 G 生成的数据和真实的数据分别开来。这样,G 和 D 构成了一个动态的“博弈过程”。

最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的数据 G(z) 。对于 D 来说,它难以判定 G 生成的数据究竟是不是真实的,因此 D(G(z)) = 0.5。

当判别器真假难辨时,D_fake,D_real->0.5,G_loss=log(1-0.5)=0.6931..., 此时 D_loss=log(1-0.5)+log(0.5)= 1.3832...

实现

这里我们的任务是:构造一个GAN网络,希望 生成器 能够输入噪声生成一个二次函数曲线

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np BATCH_SIZE = 64
G_IN_SIZE = 15 #生成器 输入尺寸
G_OUT_SIZE = 15 #生成器 输出尺寸 PAINT_POINTS = np.vstack([np.linspace(-1,1, G_OUT_SIZE) for _ in range(BATCH_SIZE)]) #shape (BATCH_SIZE, G_OUT_SIZE) plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='Real Curve') #2 * x^2 + 1
plt.legend(loc='upper right') #标签位置
plt.show()

# 准备真实数据
def real_points():
paints = 2 * np.power(PAINT_POINTS,2) + 1
paints = torch.from_numpy(paints).float()
return paints #定义网络
G = nn.Sequential(
nn.Linear(G_IN_SIZE,128),
nn.ReLU(),
nn.Linear(128,G_OUT_SIZE)
) D = nn.Sequential(
nn.Linear(G_OUT_SIZE,128),
nn.ReLU(),
nn.Linear(128,1),
nn.Sigmoid() #0为False,1为True D的评估应该是在【0-1】之间的数值,所以这里采用的是Sigmod激活
) # 优化函数
optimizer_G = torch.optim.Adam(G.parameters(),lr=0.0001)
optimizer_D = torch.optim.Adam(D.parameters(),lr=0.0001) bceloss = nn.BCELoss() #训练
for step in range(10001):
real_data = real_points() # 生成真实数据
# print('real_data', real_data.shape)
randn_input = torch.randn(BATCH_SIZE, G_IN_SIZE) #输入噪声
# print('randn_input', randn_input.shape) eps = 1e-6 #防止log 0 D_real = D(real_data) # 0为False,1为True,这里输入真实数据,D_real越靠近1越好
D_fake = D(G(randn_input)) #训练判别器D,根据公式 D_loss 分为两个部分:判断真实数据 log(1-D_real);判断假数据 log(D_fake)
# D带着G一起更新,使用D(G(input))
# D_loss = -torch.mean(torch.log(eps + 1.0 - D_real) + torch.log(eps + D_fake))
D_loss = bceloss(1-D_real, torch.ones_like(D_real)) + bceloss(1-D_fake, torch.zeros_like(D_fake)) optimizer_D.zero_grad()
D_loss.backward()
optimizer_D.step() #训练生成器G
G_fake_out = G(randn_input) # 生成器生成假数据
D_fake = D(G_fake_out) # 用判别器判别假数据,最好能让判别器判断概率趋近0.5,即生成器生成的假数据,能让判别器真假难辨
# G的损失 越接近1越好,当D_fake->0.5时,G_loss=log(1-0.5)=0.6931..., 此时 D_loss=log(1-0.5)+log(0.5)= 1.3832...
# G_loss = -torch.mean(torch.log(1.0 - D_fake + eps))
G_loss = bceloss(D_fake, torch.zeros_like(D_fake)) optimizer_G.zero_grad()
G_loss.backward() #反向
optimizer_G.step() #更新G参数 if step % 1000 == 0: # plotting
plt.cla()
plt.plot(PAINT_POINTS[0], G_fake_out.data.numpy()[0], c='#4AD631', lw=3, label='Generated Curve',)
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='Real Curve')
plt.text(-1.0, 0.4, 'G_loss= %.3f ' % G_loss.data.numpy(), fontdict={'size': 13})
plt.text(-1.0, 0.2, 'D_loss= %.3f ' % D_loss.data.numpy(), fontdict={'size': 13})
plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.1)

GAN网络的损失函数

另外的,GAN网络的损失函数也可以使用 BCELoss(Binary Cross Entropy Loss) 或 BCEWithLogitsLoss

BCELoss的解释如下:

为了解决无穷的问题,当log小于-100时,固定输出-100.

reduction 默认 mean,x 为 D预测值,y 为实际值(0, 1),则有 loss(x, y) = - mean(y * log(x) + (1-y) * log(1-x))

又因 \(x\in(0,1)\) 所以 log(x)<0, log(1-x)<0, mean前加上负号,使得 loss(x, y) > 0

由图一公式 $ \mathop{min}\limits_{G}\mathop{max}\limits_{D} V(G,D)=E_{x \sim real}[log(D(x))]+E_{z \sim noize}[log(1-D(G(z)))]$ 可知, G 的 目标越小越好, D 的 目标越大越好。

只对G网络计算loss时,G_loss 越小越好, G 网络生成的全是假数据,则x=D(G(z)), y=0,G_loss = BCELoss(x, 0) = -mean(log(1-D(G(z))))

对D网络计算loss时,D_Loss越大越好,而 \(x\in(0,1)\), 所以可以令 \(x_k=1-x\) 把越大越好问题 转化成 越小越好的问题。

输入分两种情况:

  • D输入G生成的假数据时 x=D(G(z)),x1=1-D(G(z)), y1=0, D_loss_1 = BCELoss(x1, y1) = -mean(log(1-x1)) = -mean(log(D(G(z))))
  • D输入真实数据时 x=D(k),x2=1-D(k), y2=1, D_loss_2 = BCELoss(x2, y2) = -mean(log(1-D(k)))

D的总损失 D_Loss = - mean(log(1-D(k)) + log(D(G(z)))) = BCELoss(1-D(G(z)), 0) + BCELoss(1-D(k), 1)

可以用 bceloss 进行替换

# G_loss = -torch.mean(torch.log(1.0 - D_fake + eps))
G_loss = bceloss(D_fake, torch.zeros_like(D_fake)) # D_loss = -torch.mean(torch.log(eps + 1.0 - D_real) + torch.log(eps + D_fake))
D_loss = bceloss(1-D_real, torch.ones_like(D_real)) + bceloss(1-D_fake, torch.zeros_like(D_fake))

Pytorch 实现 GAN 网络的更多相关文章

  1. pytorch训练GAN时的detach()

    我最近在学使用Pytorch写GAN代码,发现有些代码在训练部分细节有略微不同,其中有的人用到了detach()函数截断梯度流,有的人没用detch(),取而代之的是在损失函数在反向传播过程中将bac ...

  2. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  3. 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上

    GAN网络架构分析 上图即为GAN的逻辑架构,其中的noise vector就是特征向量z,real images就是输入变量x,标签的标准比较简单(二分类么),real的就是tf.ones,fake ...

  4. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  5. GAN网络原理介绍和代码

    GAN网络的整体公式: 公式各参数介绍如下: X是真实地图片,而对应的标签是1. G(Z)是通过给定的噪声Z,生成图片(实际上是通过给定的Z生成一个tensor),对应的标签是0. D是一个二分类网络 ...

  6. KL散度的理解(GAN网络的优化)

    原文地址Count Bayesie 这篇文章是博客Count Bayesie上的文章Kullback-Leibler Divergence Explained 的学习笔记,原文对 KL散度 的概念诠释 ...

  7. PyTorch对ResNet网络的实现解析

    PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...

  8. GAN网络进行图片增强

    GAN网络进行图片增强 基于Tensorflow框架 调用ModifyPictureSize.py文件 代码如下: from skimage import io,transform,color imp ...

  9. 常见的GAN网络的相关原理及推导

    常见的GAN网络的相关原理及推导 在上一篇中我们给大家介绍了GAN的相关原理和推导,GAN是VAE的后一半,再加上一个鉴别网络.这样而导致了完全不同的训练方式. GAN,生成对抗网络,主要有两部分构成 ...

  10. GAN网络从入门教程(一)之GAN网络介绍

    GAN网络从入门教程(一)之GAN网络介绍 稍微的开一个新坑,同样也是入门教程(因此教程的内容不会是从入门到精通,而是从入门到入土).主要是为了完成数据挖掘的课程设计,然后就把挖掘榔头挖到了GAN网络 ...

随机推荐

  1. matplotlib中渐变颜色条转CSS样式(hex格式)——同mapbox中cog的颜色条拉伸显示

    matplotlib中渐变颜色条转CSS样式(hex格式)--同mapbox中cog的颜色条拉伸显示 应用场景: 1.适用于mapbox中显示cog影像时,colormap_name拉伸颜色条转换 2 ...

  2. web3 产品介绍 Dune Analytics 区块链的数据探索和可视化 链上热点和趋势一手掌握

    Dune Analytics 是一个强大的数据分析平台,旨在帮助用户在区块链上进行数据探索和可视化. Dune Analytics的特点: 数据查询与可视化:Dune Analytics允许用户从多个 ...

  3. Tomcat日志信息有乱码的处理方法

    1.问题描述 1.1.Idea中的tomcat日志有乱码 1.2.直接启动tomcat的日志有乱码 1.3.原因分析 问题是由于tomcat使用的编码和操作系统使用的编码不一致导致: Windows1 ...

  4. 【Java-GUI】04 菜单

    --1.菜单组件 相关对象: MenuBar 菜单条 Menu 菜单容器 PopupMenu 上下文菜单(右键弹出菜单组件) MenuItem 菜单项 CheckboxMenuItem 复选框菜单项 ...

  5. Parallel and Sequential Data Structures and Algorithms

    并串行 从零开始考前突击并串行数据结构与算法 强烈建议和原教材参照着看 Introduction 本书的要点 定义问题 不同的算法解决 设计抽象数据类型和相应的数据结构实现 分析比较算法和数据类型的代 ...

  6. 如何让你的C语言程序打印的log多一点色彩?(超级实用)

    接着上一篇文章<由字节对齐引发的一场"血案" > 在平常的调试中,printf字体格式与颜色都是默认一致的. 如果可以根据log信息的重要程度,配以不同的颜色与格式,可 ...

  7. 5 个有趣的 Python 开源项目「GitHub 热点速览」

    本期,我从上周的开源热搜项目中精心挑选了 5 个有趣.好玩的 Python 开源项目. 首先是 PyScript,它可以让你直接在浏览器中运行 Python 代码,不仅支持在 HTML 中嵌入,还能安 ...

  8. manim边学边做--直线类

    直线是最常用的二维结构,也是构造其他二维图形的基础.manim中针对线性结构提供了很多模块,本篇主要介绍常用的几个直线类的模块. Line:通用直线 DashedLine:各种类型的虚线 Tangen ...

  9. 微信小程序之无需服务端支持实现内容安全检查

    微信小程序之无需服务端支持实现内容安全检查 微信小程序审核未通过,原因如下: 为避免您的小程序被滥用,请你完善内容审核机制,如调用小程序内容安全API,或使用其他技术.人工审核手段,过滤色情.违法等有 ...

  10. spring boot 若依系统整合Ueditor,部署时候上传图片错误解决

    spring boot 若依系统整合Ueditor,部署时候上传图片错误解决 前言:国庆假期找了个ruoyi版本的cms玩玩,从git上看,介绍如下图: 后台部分截图: 编辑 ​ 编辑 ​ 编辑 ​ ...