Generative Adversarial Network (GAN) - Pytorch版
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image # 配置GPU或CPU设置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 超参数设置
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples' # Create a directory if not exists
if not os.path.exists(sample_dir):
os.makedirs(sample_dir) # Pytorch:transforms的二十二个方法:https://blog.csdn.net/weixin_38533896/article/details/86028509#10transformsNormalize_120
# 对Image数据按通道进行标准化,即先减均值,再除以标准差,注意是 hwc
transform = transforms.Compose([
transforms.ToTensor(),# 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1],归一化至[0-1]是直接除以255
transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels
std=(0.5, 0.5, 0.5))]) # 下载数据,并指定转换形式transform
# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data/',
train=True,
transform=transform,
download=True)
# 数据加载,按照batch_size大小加载,并随机打乱
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)
# 鉴别器
# Discriminator
D = nn.Sequential(
nn.Linear(image_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, 1),
nn.Sigmoid())
# 生成器
# Generator
G = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, image_size),
nn.Tanh()) # GPU或CPU设置
# Device setting
D = D.to(device)
print(D)
# Sequential((0): Linear(in_features=784, out_features=256, bias=True)
# (1): LeakyReLU(negative_slope=0.2)
# (2): Linear(in_features=256, out_features=256, bias=True)
# (3): LeakyReLU(negative_slope=0.2)
# (4): Linear(in_features=256, out_features=1, bias=True)
# (5): Sigmoid())
G = G.to(device)
print(G)
# Sequential( (0): Linear(in_features=64, out_features=256, bias=True)
# (1): ReLU()
# (2): Linear(in_features=256, out_features=256, bias=True)
# (3): ReLU()
# (4): Linear(in_features=256, out_features=784, bias=True)
# (5): Tanh()) # 二值交叉熵损失函数和优化器设置
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
# 优化器设置 ,并传入鉴别器与生成器模型参数和相应的学习率
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002) # 规范化处理
def denorm(x):
out = (x + 1) / 2
return out.clamp(0, 1) # 将out张量每个元素的范围限制到区间 [min,max] # 清空上一步的残余更新参数值
def reset_grad():
d_optimizer.zero_grad() # 清空鉴别器的梯度器上一步的残余更新参数值
g_optimizer.zero_grad() # 清空生成器的梯度器上一步的残余更新参数值 # 开始训练
total_step = len(data_loader)
for epoch in range(num_epochs):
for i, (images, _) in enumerate(data_loader):
images = images.reshape(batch_size, -1).to(device) # 创建label
# Create the labels which are later used as input for the BCE loss
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device) # ================================================================== #
# 训练鉴别器 #
# ================================================================== #
# 使用真实图像计算二值交叉熵损失
# Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
# Second term of the loss is always zero since real_labels == 1
outputs = D(images)# 真图像输入给鉴别器,并产生鉴别器输出
d_loss_real = criterion(outputs, real_labels) # 计算由真图像输入给鉴别器产生的输出与真实的label间的二值交叉熵损失
real_score = outputs# 鉴别器输出真实图像score值 # Compute BCELoss using fake images
# First term of the loss is always zero since fake_labels == 0
z = torch.randn(batch_size, latent_size).to(device)# 随机生成假图像
fake_images = G(z)# 假图像输入给生成器,并产生生成器输出假值图
outputs = D(fake_images)# 生成器输出假值图给鉴别器鉴别,输出鉴别结果
d_loss_fake = criterion(outputs, fake_labels)# 由随机产生的假图像输入给生成器产生的假图,计算生成器生成的假图输入给鉴别器鉴别输出与假的标签间的二值交叉熵损失
fake_score = outputs# 鉴别器输出假图像score值 # 反向传播与优化
d_loss = d_loss_real + d_loss_fake#真图像输入给鉴别器产生的输出与真实的label间的二值交叉熵损失和假图输入给鉴别器鉴别输出与假的标签间的二值交叉熵损失
# 重置梯度求解器
reset_grad()
# 反向传播
d_loss.backward()
# 将参数更新值施加到鉴别器 model的parameters上
d_optimizer.step() # ================================================================== #
# 训练生成器 #
# ================================================================== #
# 计算假图像的损失
# Compute loss with fake images
z = torch.randn(batch_size, latent_size).to(device)# 随机生成假图像
fake_images = G(z)# 假图像输入给生成器,并产生生成器输出假值图
outputs = D(fake_images)# 生成器输出假值图给鉴别器鉴别,输出鉴别结果 # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
g_loss = criterion(outputs, real_labels)# 由随机产生的假图像输入给生成器产生的假图,计算生成器生成的假图输入给鉴别器鉴别输出与真的标签间的二值交叉熵损失 # 反向传播与优化
# 重置梯度求解器
reset_grad()
# 反向传播
g_loss.backward()
# 将参数更新值施加到生成器 model的parameters上
g_optimizer.step()
# 每迭代一定步骤,打印结果值
if (i + 1) % 200 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
.format(epoch, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))
# 保存真图像
# Save real images
if (epoch + 1) == 1:
images = images.reshape(images.size(0), 1, 28, 28)
save_image(denorm(images), os.path.join(sample_dir, 'real_images.png')) # 保存假或采样图像
# Save sampled images
fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1))) # 保存以训练好的生成器与鉴别器模型
# Save the model checkpoints
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
Generative Adversarial Network (GAN) - Pytorch版的更多相关文章
- GAN Generative Adversarial Network 生成式对抗网络-相关内容
参考: https://baijiahao.baidu.com/s?id=1568663805038898&wfr=spider&for=pc Generative Adversari ...
- Face Aging with Conditional Generative Adversarial Network 论文笔记
Face Aging with Conditional Generative Adversarial Network 论文笔记 2017.02.28 Motivation: 本文是要根据最新的条件产 ...
- 生成对抗网络(Generative Adversarial Network)阅读笔记
笔记持续更新中,请大家耐心等待 首先需要大概了解什么是生成对抗网络,参考维基百科给出的定义(https://zh.wikipedia.org/wiki/生成对抗网络): 生成对抗网络(英语:Gener ...
- ASRWGAN: Wasserstein Generative Adversarial Network for Audio Super Resolution
ASEGAN:WGAN音频超分辨率 这篇文章并不具有权威性,因为没有发表,说不定是外国的某个大学的毕业设计,或者课程结束后的作业.或者实验报告. CS230: Deep Learning, Sprin ...
- Speech Super Resolution Generative Adversarial Network
博客作者:凌逆战 博客地址:https://www.cnblogs.com/LXP-Never/p/10874993.html 论文作者:Sefik Emre Eskimez , Kazuhito K ...
- DeepPrivacy: A Generative Adversarial Network for Face Anonymization阅读笔记
DeepPrivacy: A Generative Adversarial Network for Face Anonymization ISVC 2019 https://arxiv.org/pdf ...
- 论文阅读之:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network 2016.10.23 摘要: ...
- 论文阅读:Single Image Dehazing via Conditional Generative Adversarial Network
Single Image Dehazing via Conditional Generative Adversarial Network Runde Li∗ Jinshan Pan∗ Zechao L ...
- 一文读懂对抗生成学习(Generative Adversarial Nets)[GAN]
一文读懂对抗生成学习(Generative Adversarial Nets)[GAN] 0x00 推荐论文 https://arxiv.org/pdf/1406.2661.pdf 0x01什么是ga ...
随机推荐
- P1041 传染病控制——暴力遍历所有相同深度的节点
P1041 传染病控制 说实话这种暴力我还是头一次见,每次病毒都会往下传染一层: 数据范围小,我们可以直接枚举当前层保护谁就好了: 用vector 记录相同层数的节点:维护已经断了的点: 如果超出最底 ...
- 谈下python的GIL
GIL 是python的全局解释器锁,同一进程中假如有多个线程运行,一个线程在运行python程序的时候会霸占python解释器(加了一把锁即GIL),使该进程内的其他线程无法运行,等该线程运行完后其 ...
- SQLServer 使用自定义端口连接的方法(转载)
使用过SQL Server的人大多都知道,SQL Server服务器默认监听的端口号是1433,但是我今天遇到的问题是我的机器上有三个数据库实例,这样使用TCP/IP远程连接时就产生了问题.如何在Mi ...
- [ZJOI2007][BZOJ1060]时态同步
Description 小Q在电子工艺实习课上学习焊接电路板.一块电路板由若干个元件组成,我们不妨称之为节点,并将其用数 字1,2,3….进行标号.电路板的各个节点由若干不相交的导线相连接,且对于电路 ...
- [WEB安全]PHP伪协议总结
0x01 简介 首先来看一下有哪些文件包含函数: include.require.include_once.require_once.highlight_file show_source .readf ...
- 从一个表中往另外一个表中插入数据用到的SQL
insert into jdjc_zzjcxm (zj,jcxmmc) select sys_guid(),zbmc from JDJC_WHJXXMMC;
- com.ibm.db2.jcc.am.SqlSyntaxErrorException: DB2 SQL Error: SQLCODE=-418, SQLSTATE=42610, SQLERRMC=null
写了一条sql,在db2数据库中可以执行,但是转换成mybatis的mapper文件后,在执行排序操作时报该错误. 我排序是这样写的 <if test="orderStr != nul ...
- python 椭球面
作者:chaowei wu链接:https://www.zhihu.com/question/266366089/answer/307037017来源:知乎著作权归作者所有.商业转载请联系作者获得授权 ...
- 商城怎么使用ajax?
1.前端: Ajax.call('order.php?act=export', params, function (data) { document.getElementById("expo ...
- JEECG Hibernate 自动更新 持久化
Hibernate不调用update却自动更新 - 七郎 - 博客园http://www.cnblogs.com/yangy608/p/4073941.html hibernate自动更新持久化类的问 ...