Pytorch 实现 GAN 网络
Pytorch 实现 GAN 网络
原理
GAN的基本原理其实非常简单,假设我们有两个网络,G(Generator)和D(Discriminator)。它们的功能分别是:
G 是一个生成网络,它接收一个随机的噪声z,通过这个噪声生成伪造数据,记做 G(z)。
D 是一个判别网络,判别数据是不是“真实的”。它的输入参数是x,输出记为 D(x) 代表 x 为真实的概率。如果为 1 就代表 x 为真的概率是100%,而输出为 0 代表为真概率是0% 即为假。

在训练过程中,生成网络 G 的目标就是尽量生成真实的数据去欺骗判别网络D。而 D 的目标就是尽量把 G 生成的数据和真实的数据分别开来。这样,G 和 D 构成了一个动态的“博弈过程”。
最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的数据 G(z) 。对于 D 来说,它难以判定 G 生成的数据究竟是不是真实的,因此 D(G(z)) = 0.5。
当判别器真假难辨时,D_fake,D_real->0.5,G_loss=log(1-0.5)=0.6931..., 此时 D_loss=log(1-0.5)+log(0.5)= 1.3832...
实现
这里我们的任务是:构造一个GAN网络,希望 生成器 能够输入噪声生成一个二次函数曲线
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
BATCH_SIZE = 64
G_IN_SIZE = 15 #生成器 输入尺寸
G_OUT_SIZE = 15 #生成器 输出尺寸
PAINT_POINTS = np.vstack([np.linspace(-1,1, G_OUT_SIZE) for _ in range(BATCH_SIZE)]) #shape (BATCH_SIZE, G_OUT_SIZE)
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='Real Curve') #2 * x^2 + 1
plt.legend(loc='upper right') #标签位置
plt.show()

# 准备真实数据
def real_points():
paints = 2 * np.power(PAINT_POINTS,2) + 1
paints = torch.from_numpy(paints).float()
return paints
#定义网络
G = nn.Sequential(
nn.Linear(G_IN_SIZE,128),
nn.ReLU(),
nn.Linear(128,G_OUT_SIZE)
)
D = nn.Sequential(
nn.Linear(G_OUT_SIZE,128),
nn.ReLU(),
nn.Linear(128,1),
nn.Sigmoid() #0为False,1为True D的评估应该是在【0-1】之间的数值,所以这里采用的是Sigmod激活
)
# 优化函数
optimizer_G = torch.optim.Adam(G.parameters(),lr=0.0001)
optimizer_D = torch.optim.Adam(D.parameters(),lr=0.0001)
bceloss = nn.BCELoss()
#训练
for step in range(10001):
real_data = real_points() # 生成真实数据
# print('real_data', real_data.shape)
randn_input = torch.randn(BATCH_SIZE, G_IN_SIZE) #输入噪声
# print('randn_input', randn_input.shape)
eps = 1e-6 #防止log 0
D_real = D(real_data) # 0为False,1为True,这里输入真实数据,D_real越靠近1越好
D_fake = D(G(randn_input))
#训练判别器D,根据公式 D_loss 分为两个部分:判断真实数据 log(1-D_real);判断假数据 log(D_fake)
# D带着G一起更新,使用D(G(input))
# D_loss = -torch.mean(torch.log(eps + 1.0 - D_real) + torch.log(eps + D_fake))
D_loss = bceloss(1-D_real, torch.ones_like(D_real)) + bceloss(1-D_fake, torch.zeros_like(D_fake))
optimizer_D.zero_grad()
D_loss.backward()
optimizer_D.step()
#训练生成器G
G_fake_out = G(randn_input) # 生成器生成假数据
D_fake = D(G_fake_out) # 用判别器判别假数据,最好能让判别器判断概率趋近0.5,即生成器生成的假数据,能让判别器真假难辨
# G的损失 越接近1越好,当D_fake->0.5时,G_loss=log(1-0.5)=0.6931..., 此时 D_loss=log(1-0.5)+log(0.5)= 1.3832...
# G_loss = -torch.mean(torch.log(1.0 - D_fake + eps))
G_loss = bceloss(D_fake, torch.zeros_like(D_fake))
optimizer_G.zero_grad()
G_loss.backward() #反向
optimizer_G.step() #更新G参数
if step % 1000 == 0: # plotting
plt.cla()
plt.plot(PAINT_POINTS[0], G_fake_out.data.numpy()[0], c='#4AD631', lw=3, label='Generated Curve',)
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='Real Curve')
plt.text(-1.0, 0.4, 'G_loss= %.3f ' % G_loss.data.numpy(), fontdict={'size': 13})
plt.text(-1.0, 0.2, 'D_loss= %.3f ' % D_loss.data.numpy(), fontdict={'size': 13})
plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.1)

GAN网络的损失函数
另外的,GAN网络的损失函数也可以使用 BCELoss(Binary Cross Entropy Loss) 或 BCEWithLogitsLoss
BCELoss的解释如下:

为了解决无穷的问题,当log小于-100时,固定输出-100.
reduction 默认 mean,x 为 D预测值,y 为实际值(0, 1),则有 loss(x, y) = - mean(y * log(x) + (1-y) * log(1-x))
又因 \(x\in(0,1)\) 所以 log(x)<0, log(1-x)<0, mean前加上负号,使得 loss(x, y) > 0
由图一公式 $ \mathop{min}\limits_{G}\mathop{max}\limits_{D} V(G,D)=E_{x \sim real}[log(D(x))]+E_{z \sim noize}[log(1-D(G(z)))]$ 可知, G 的 目标越小越好, D 的 目标越大越好。

只对G网络计算loss时,G_loss 越小越好, G 网络生成的全是假数据,则x=D(G(z)), y=0,G_loss = BCELoss(x, 0) = -mean(log(1-D(G(z))))
对D网络计算loss时,D_Loss越大越好,而 \(x\in(0,1)\), 所以可以令 \(x_k=1-x\) 把越大越好问题 转化成 越小越好的问题。
输入分两种情况:
- D输入G生成的假数据时 x=D(G(z)),x1=1-D(G(z)), y1=0, D_loss_1 = BCELoss(x1, y1) = -mean(log(1-x1)) = -mean(log(D(G(z))))
- D输入真实数据时 x=D(k),x2=1-D(k), y2=1, D_loss_2 = BCELoss(x2, y2) = -mean(log(1-D(k)))
D的总损失 D_Loss = - mean(log(1-D(k)) + log(D(G(z)))) = BCELoss(1-D(G(z)), 0) + BCELoss(1-D(k), 1)
可以用 bceloss 进行替换
# G_loss = -torch.mean(torch.log(1.0 - D_fake + eps))
G_loss = bceloss(D_fake, torch.zeros_like(D_fake))
# D_loss = -torch.mean(torch.log(eps + 1.0 - D_real) + torch.log(eps + D_fake))
D_loss = bceloss(1-D_real, torch.ones_like(D_real)) + bceloss(1-D_fake, torch.zeros_like(D_fake))
Pytorch 实现 GAN 网络的更多相关文章
- pytorch训练GAN时的detach()
我最近在学使用Pytorch写GAN代码,发现有些代码在训练部分细节有略微不同,其中有的人用到了detach()函数截断梯度流,有的人没用detch(),取而代之的是在损失函数在反向传播过程中将bac ...
- MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...
- 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上
GAN网络架构分析 上图即为GAN的逻辑架构,其中的noise vector就是特征向量z,real images就是输入变量x,标签的标准比较简单(二分类么),real的就是tf.ones,fake ...
- 『TensorFlow』通过代码理解gan网络_中
『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...
- GAN网络原理介绍和代码
GAN网络的整体公式: 公式各参数介绍如下: X是真实地图片,而对应的标签是1. G(Z)是通过给定的噪声Z,生成图片(实际上是通过给定的Z生成一个tensor),对应的标签是0. D是一个二分类网络 ...
- KL散度的理解(GAN网络的优化)
原文地址Count Bayesie 这篇文章是博客Count Bayesie上的文章Kullback-Leibler Divergence Explained 的学习笔记,原文对 KL散度 的概念诠释 ...
- PyTorch对ResNet网络的实现解析
PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...
- GAN网络进行图片增强
GAN网络进行图片增强 基于Tensorflow框架 调用ModifyPictureSize.py文件 代码如下: from skimage import io,transform,color imp ...
- 常见的GAN网络的相关原理及推导
常见的GAN网络的相关原理及推导 在上一篇中我们给大家介绍了GAN的相关原理和推导,GAN是VAE的后一半,再加上一个鉴别网络.这样而导致了完全不同的训练方式. GAN,生成对抗网络,主要有两部分构成 ...
- GAN网络从入门教程(一)之GAN网络介绍
GAN网络从入门教程(一)之GAN网络介绍 稍微的开一个新坑,同样也是入门教程(因此教程的内容不会是从入门到精通,而是从入门到入土).主要是为了完成数据挖掘的课程设计,然后就把挖掘榔头挖到了GAN网络 ...
随机推荐
- 对比python学julia(第一章)--(第六节)数字黑洞
6.1. 问题描述 6174数字黑洞是印度数学家卡普雷卡尔于1949年发现的,又称为卡普雷卡尔黑洞,其规则描述如下. 任意取一个4位的整数(4个数字不能完全相同),把4个数字由大到小排列成一个大的数, ...
- 美国小伙: "American Guy: Only communism can save America!"
视频地址: https://www.youtube.com/watch?v=Y_WQnXFh8ss 2024大选在即,又是拜登对阵特朗普的旧日重现.在角逐谁的对手反对者更多的畸形内耗中,有一个名为 M ...
- openAI的仿真环境Gym Retro的Python API接口
如题,本文主要介绍仿真环境Gym Retro的Python API接口 . 官网地址: https://retro.readthedocs.io/en/latest/python.html ===== ...
- 美的(Midea)超声波清洗机 眼镜清洗机 超声波洗眼镜 首饰剃须刀手表假牙牙套化妆刷 洗眼镜机超声波 MXV-01 —— 工业设计上的重大问题分析
前段时间买了一个美的的超声波清洗机,就是那种超声波洗眼镜的那种,本着买个高档的可以分体的那种好清洗的原则,就在JD上千挑万选后买了下面的这个货: 链接地址: https://item.jd.com/1 ...
- 【转载】 机器学习的高维数据可视化技术(t-SNE 介绍) 外文博客原文:How t-SNE works and Dimensionality Reduction
原文地址: https://www.displayr.com/using-t-sne-to-visualize-data-before-prediction/ 该文是网上传的比较多的一个 t-SNE ...
- CF208E 题解
Blood Cousins 前置知识:线段树合并. 我们先把题目转化一下.这里先设 \(v\) 的 \(p\) 级祖先为 \(u\),事实上要求的东西就是 \(u\) 的 \(p\) 级后代的个数减 ...
- 在python项目的docker镜像里使用pdm管理依赖
前言 在 DjangoStarter 项目中,我已经使用 pdm 作为默认的包管理器,不再直接使用 pip 所以部署的时候 dockerfile 和 docker-compose 配置也得修改一下. ...
- SMU Summer 2024 Contest Round 7
SMU Summer 2024 Contest Round 7 Make Equal With Mod 题意 给定一个长度为 \(n\) 的数列 \(a\).你可以执行若干次操作,每次操作选择一个大于 ...
- SMU Summer 2024 Contest Round 2
SMU Summer 2024 Contest Round 2 Sierpinski carpet 题意 给一个整数 n ,输出对应的 \(3^n\times 3^n\) 的矩阵. 思路 \(n = ...
- PHP转Go系列 | ThinkPHP与Gin框架之打造基于WebSocket技术的消息推送中心
大家好,我是码农先森. 在早些年前客户端想要实时获取到最新消息,都是使用定时长轮询的方式,不断的从服务器上获取数据,这种粗暴的骚操作实属不雅.不过现如今我也还见有人还在一些场景下使用,比如在 PC 端 ...