Implement GAN from scratch
GANs from Scratch 1: A deep introduction. With code in PyTorch and TensorFlow
修改文章代码中的错误后的代码如下:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
DATA_FOLDER = 'D:/WorkSpace/Data/torchvision_data'
def mnist_data():
compose = transforms.Compose(
[transforms.ToTensor(),
# transforms.Normalize((.5, .5, .5), (.5, .5, .5))
transforms.Normalize([0.5], [0.5]) # MNIST只有一个通道
])
return datasets.MNIST(root=DATA_FOLDER, train=True, transform=compose)
# Load data
data = mnist_data()
# Create loader with data, so that we can iterate over it
data_loader = torch.utils.data.DataLoader(data, batch_size=64, shuffle=True)
# Num batches
num_batches = len(data_loader)
class DiscriminatorNet(torch.nn.Module):
"""
A three hidden-layer discriminative neural network
"""
def __init__(self):
super(DiscriminatorNet, self).__init__()
n_features = 784
n_out = 1
self.hidden0 = nn.Sequential(
nn.Linear(n_features, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3)
)
self.hidden1 = nn.Sequential(
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3)
)
self.hidden2 = nn.Sequential(
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3)
)
self.out = nn.Sequential(
torch.nn.Linear(256, n_out),
torch.nn.Sigmoid()
)
def forward(self, x):
x = self.hidden0(x)
x = self.hidden1(x)
x = self.hidden2(x)
x = self.out(x)
return x
def images_to_vectors(images):
return images.view(images.size(0), 784)
def vectors_to_images(vectors):
return vectors.view(vectors.size(0), 1, 28, 28)
class GeneratorNet(torch.nn.Module):
"""
A three hidden-layer generative neural network
"""
def __init__(self):
super(GeneratorNet, self).__init__()
n_features = 100
n_out = 784
self.hidden0 = nn.Sequential(
nn.Linear(n_features, 256),
nn.LeakyReLU(0.2)
)
self.hidden1 = nn.Sequential(
nn.Linear(256, 512),
nn.LeakyReLU(0.2)
)
self.hidden2 = nn.Sequential(
nn.Linear(512, 1024),
nn.LeakyReLU(0.2)
)
self.out = nn.Sequential(
nn.Linear(1024, n_out),
nn.Tanh()
)
def forward(self, x):
x = self.hidden0(x)
x = self.hidden1(x)
x = self.hidden2(x)
x = self.out(x)
return x
# Noise
def noise(size):
n = Variable(torch.randn(size, 100))
if torch.cuda.is_available(): return n.cuda()
return n
discriminator = DiscriminatorNet()
generator = GeneratorNet()
if torch.cuda.is_available():
discriminator.cuda()
generator.cuda()
# Optimizers
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
# Loss function
loss = nn.BCELoss()
# Number of steps to apply to the discriminator
d_steps = 1 # In Goodfellow et. al 2014 this variable is assigned to 1
# Number of epochs
num_epochs = 200
def real_data_target(size):
'''
Tensor containing ones, with shape = size
'''
data = Variable(torch.ones(size, 1))
if torch.cuda.is_available(): return data.cuda()
return data
def fake_data_target(size):
'''
Tensor containing zeros, with shape = size
'''
data = Variable(torch.zeros(size, 1))
if torch.cuda.is_available(): return data.cuda()
return data
def train_discriminator(optimizer, real_data, fake_data):
# Reset gradients
optimizer.zero_grad()
# 1.1 Train on Real Data
prediction_real = discriminator(real_data)
# Calculate error and backpropagate
error_real = loss(prediction_real, real_data_target(real_data.size(0)))
error_real.backward()
# 1.2 Train on Fake Data
prediction_fake = discriminator(fake_data)
# Calculate error and backpropagate
error_fake = loss(prediction_fake, fake_data_target(real_data.size(0)))
error_fake.backward()
# 1.3 Update weights with gradients
optimizer.step()
# Return error
return error_real + error_fake, prediction_real, prediction_fake
def train_generator(optimizer, fake_data):
# 2. Train Generator
# Reset gradients
optimizer.zero_grad()
# Sample noise and generate fake data
prediction = discriminator(fake_data)
# Calculate error and backpropagate
error = loss(prediction, real_data_target(prediction.size(0)))
error.backward()
# Update weights with gradients
optimizer.step()
# Return error
return error
num_test_samples = 16
test_noise = noise(num_test_samples)
for epoch in range(num_epochs):
for n_batch, (real_batch,_) in enumerate(data_loader):
# 1. Train Discriminator
real_data = Variable(images_to_vectors(real_batch))
if torch.cuda.is_available(): real_data = real_data.cuda()
# Generate fake data
fake_data = generator(noise(real_data.size(0))).detach()
# Train D
d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer,
real_data, fake_data)
# 2. Train Generator
# Generate fake data
fake_data = generator(noise(real_batch.size(0)))
# Train G
g_error = train_generator(g_optimizer, fake_data)
# Display Progress
print('epoch ', epoch, ': ','d_error is ', d_error, 'g_error is ', g_error)
if (epoch) % 20 == 0:
test_images = vectors_to_images(generator(test_noise)).data.cpu()
fig = plt.figure()
for i in range(len(test_images)):
ax = fig.add_subplot(4, 4, i+1)
ax.imshow(test_images[i][0], cmap=plt.cm.gray)
plt.show()
Implement GAN from scratch的更多相关文章
- 机器学习算法之旅A Tour of Machine Learning Algorithms
In this post we take a tour of the most popular machine learning algorithms. It is useful to tour th ...
- ML-学习提纲2
https://machinelearningmastery.com/a-tour-of-machine-learning-algorithms/ http://blog.csdn.net/u0110 ...
- [C4] Andrew Ng - Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization
About this Course This course will teach you the "magic" of getting deep learning to work ...
- [RxJS] Implement the `map` Operator from Scratch in RxJS
While it's great to use the RxJS built-in operators, it's also important to realize you now have the ...
- How to implement an algorithm from a scientific paper
Author: Emmanuel Goossaert 翻译 This article is a short guide to implementing an algorithm from a scie ...
- A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python)
A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python) MACHINE LEARNING PYTHON ...
- Learning WCF Chapter1 Creating a New Service from Scratch
You’re about to be introduced to the WCF service. This lab isn’t your typical “Hello World”—it’s “He ...
- Developing a Custom Membership Provider from the scratch, and using it in the FBA (Form Based Authentication) in SharePoint 2010
//http://blog.sharedove.com/adisjugo/index.php/2011/01/05/writing-a-custom-membership-provider-and-u ...
- [Laravel] 14 - REST API: Laravel from scratch
前言 一.基础 Ref: Build a REST API with Laravel API resources Goto: [Node.js] 08 - Web Server and REST AP ...
随机推荐
- 从入门到自闭之Python三大器--生成器
1.什么是生成器 核心:生成器的本质就是一个迭代器 迭代器是python自带的的 生成器是程序员自己写的一种迭代器 编写方式: 基于函数编写 推导式编写 def func (): print(&quo ...
- 面试题1-十进制数转化为十六进制数,不使用hex方法
问题: 给定一个整数,写一个算法将它转换为16进制,对于负数,可以使用two’s complement方法 def tohex(num): """十进制数转十六进制数&q ...
- Spring 中的bean 是线程安全的吗?
结论: 不是线程安全的 Spring容器中的Bean是否线程安全,容器本身并没有提供Bean的线程安全策略,因此可以说Spring容器中的Bean本身不具备线程安全的特性,但是具体还是要结合具体sco ...
- docopt 安装及基本应用
什么是 docopt docopt是一种python 编写的命令行执行脚本的交互语言. 它是一种语言! 它是一种语言! 它是一种语言! 使用这种语言可以在自己的脚本中,添加一些规则限制,这样脚本在执行 ...
- windows下生成zlib1.dll
一.原料: VC zlib-1.2.3-src.zip 二.解压zlib-1.2.3-src.zip,用VC打开工作空间 src/zlib/1.2.3/zlib-1.2.3/projects/visu ...
- spring boot本地开发与docker容器化部署的差异
spring boot本地开发与docker容器化部署的差异: 1. 文件路径及文件名区别大小写: 本地开发环境为windows操作系统,是忽略大小写的,但容器中区分大小写 2. docker中的容器 ...
- C# 面向对象5 this关键字和析构函数
this关键字 1.代表当前类的对象 2.在类当中显示的调用本类的构造函数(避免代码的冗余) 语法: ":this" 以下一个参数的构造函数调用了参数最全的构造函数!并赋值了那些不 ...
- 移动端实1px细线方法
前言 在移动端中,宽度100%,1px的线看起来要比pc端中宽度100%,1px的线粗, 那是因为css中的1px并不等于移动设备(物理像素)的1px.物理像素显示是1个像素代表2个像素,所以出现为2 ...
- JS基础_字面量和变量
<!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...
- AGC009E Eternal Average
atc 神题orz 那个擦掉\(k\)个数然后写上一个平均值可以看成是\(k\)叉Huffman树的构造过程,每次选\(k\)个点合成一个新点,然后权值设为平均值.这些0和1都会在叶子的位置,同时每个 ...