GANs from Scratch 1: A deep introduction. With code in PyTorch and TensorFlow

修改文章代码中的错误后的代码如下:

import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import matplotlib.pyplot as plt

# Root directory where torchvision stores / looks for the MNIST files.
DATA_FOLDER = 'D:/WorkSpace/Data/torchvision_data'


def mnist_data(root=DATA_FOLDER, train=True, download=True):
    """Return the MNIST dataset with pixel values scaled to [-1, 1].

    ``ToTensor`` maps pixels to [0, 1]; ``Normalize([0.5], [0.5])`` then
    applies ``(x - 0.5) / 0.5``, giving [-1, 1], the same range as the
    generator's ``Tanh`` output.

    Args:
        root: Directory holding (or receiving) the MNIST files.
        train: Load the training split when True, the test split otherwise.
        download: Fetch the files into ``root`` if they are not already
            present, instead of failing because the dataset is missing.

    Returns:
        A ``torchvision.datasets.MNIST`` instance.
    """
    compose = transforms.Compose([
        transforms.ToTensor(),
        # MNIST images have a single channel, so mean/std take one value each.
        transforms.Normalize([0.5], [0.5]),
    ])
    return datasets.MNIST(root=root, train=train, transform=compose,
                          download=download)
# Load data
data = mnist_data()
# Wrap the dataset in a loader that yields shuffled mini-batches of 64 samples.
data_loader = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=64,
    shuffle=True,
)
# Number of mini-batches per epoch.
num_batches = len(data_loader)


class DiscriminatorNet(nn.Module):
    """
    A three hidden-layer discriminative neural network.

    Input: flattened images of shape ``(batch, 784)``.
    Output: ``(batch, 1)`` Sigmoid probabilities that each input is real.
    """

    def __init__(self):
        super().__init__()
        n_features = 784
        n_out = 1

        # Each hidden layer: Linear -> LeakyReLU(0.2) -> Dropout(0.3).
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        self.out = nn.Sequential(
            nn.Linear(256, n_out),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x


def images_to_vectors(images):
    """Flatten a batch of images into ``(batch, 784)`` vectors (784 = 28 * 28)."""
    return images.view(images.size(0), 784)


def vectors_to_images(vectors):
    """Reshape ``(batch, 784)`` vectors into ``(batch, 1, 28, 28)`` images."""
    return vectors.view(vectors.size(0), 1, 28, 28)


class GeneratorNet(nn.Module):
    """
    A three hidden-layer generative neural network.

    Input: latent noise of shape ``(batch, 100)``.
    Output: ``(batch, 784)`` flattened images with values in [-1, 1] (Tanh).
    """

    def __init__(self):
        super().__init__()
        n_features = 100
        n_out = 784

        # Each hidden layer: Linear -> LeakyReLU(0.2), widening 256 -> 1024.
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
        )
        # Tanh keeps outputs in [-1, 1], matching the normalized real images.
        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x
# Noise
def noise(size):
    """Return a ``(size, 100)`` tensor of standard-normal latent vectors.

    The tensor is moved to the GPU when CUDA is available, the same condition
    under which the models are moved there.
    """
    # torch.autograd.Variable is a deprecated no-op wrapper; a plain tensor
    # behaves identically.
    n = torch.randn(size, 100)
    if torch.cuda.is_available():
        return n.cuda()
    return n


discriminator = DiscriminatorNet()
generator = GeneratorNet()
# Move both models to the GPU when one is available.
if torch.cuda.is_available():
    discriminator.cuda()
    generator.cuda()
# Optimizers: Adam with the same learning rate for both networks.
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=2e-4)
g_optimizer = optim.Adam(params=generator.parameters(), lr=2e-4)

# Loss function: binary cross-entropy on the discriminator's probabilities.
loss = torch.nn.BCELoss()

# Number of discriminator steps per generator step
# (Goodfellow et al. 2014 use 1).
# NOTE(review): not referenced by the visible training loop, which always
# performs exactly one discriminator update per batch.
d_steps = 1
# Number of epochs
num_epochs = 200


def real_data_target(size):
    """Return a ``(size, 1)`` tensor of ones: the BCE target for "real"."""
    data = torch.ones(size, 1)
    if torch.cuda.is_available():
        return data.cuda()
    return data


def fake_data_target(size):
    """Return a ``(size, 1)`` tensor of zeros: the BCE target for "fake"."""
    data = torch.zeros(size, 1)
    if torch.cuda.is_available():
        return data.cuda()
    return data


def train_discriminator(optimizer, real_data, fake_data):
    """Run one discriminator update on a batch of real and a batch of fake data.

    Gradients of the real and fake losses are accumulated, then applied in a
    single optimizer step.

    Args:
        optimizer: Optimizer over the discriminator's parameters.
        real_data: Flattened real images, ``(batch, 784)``.
        fake_data: Flattened generated images; the training loop detaches
            them so this update does not backpropagate into the generator.

    Returns:
        ``(error_real + error_fake, prediction_real, prediction_fake)``.
    """
    # Reset gradients
    optimizer.zero_grad()

    # 1.1 Train on real data: the discriminator should output 1.
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, real_data_target(real_data.size(0)))
    error_real.backward()

    # 1.2 Train on fake data: the discriminator should output 0.
    # The target is sized from fake_data, the tensor it is compared with.
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, fake_data_target(fake_data.size(0)))
    error_fake.backward()

    # 1.3 Update weights with the accumulated gradients.
    optimizer.step()

    return error_real + error_fake, prediction_real, prediction_fake


def train_generator(optimizer, fake_data):
    """Run one generator update.

    The generator is rewarded when the discriminator labels its samples as
    real, so the loss is BCE against a target of ones.

    Args:
        optimizer: Optimizer over the generator's parameters.
        fake_data: Generated samples still attached to the generator graph.

    Returns:
        The generator loss tensor.
    """
    # 2. Train Generator
    optimizer.zero_grad()
    prediction = discriminator(fake_data)
    error = loss(prediction, real_data_target(prediction.size(0)))
    error.backward()
    optimizer.step()
    return error


num_test_samples = 16
# Fixed noise so the preview images are comparable across epochs.
test_noise = noise(num_test_samples)

for epoch in range(num_epochs):
    for n_batch, (real_batch, _) in enumerate(data_loader):
        # 1. Train Discriminator
        real_data = images_to_vectors(real_batch)
        if torch.cuda.is_available():
            real_data = real_data.cuda()
        # Detach so the discriminator update does not reach the generator.
        fake_data = generator(noise(real_data.size(0))).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(
            d_optimizer, real_data, fake_data)

        # 2. Train Generator on a fresh batch of samples.
        fake_data = generator(noise(real_batch.size(0)))
        g_error = train_generator(g_optimizer, fake_data)

    # Display progress once per epoch: losses of the epoch's last batch,
    # as plain numbers rather than graph-carrying tensor reprs.
    print('epoch ', epoch, ': ', 'd_error is ', d_error.item(),
          'g_error is ', g_error.item())
    if epoch % 20 == 0:
        # Sampling for display needs no autograd graph.
        with torch.no_grad():
            test_images = vectors_to_images(generator(test_noise)).cpu()
        fig = plt.figure()
        for i, image in enumerate(test_images):
            ax = fig.add_subplot(4, 4, i + 1)
            ax.imshow(image[0], cmap=plt.cm.gray)
        plt.show()
        # Avoid accumulating open figures when show() does not block.
        plt.close(fig)

Implement GAN from scratch的更多相关文章

  1. 机器学习算法之旅A Tour of Machine Learning Algorithms

    In this post we take a tour of the most popular machine learning algorithms. It is useful to tour th ...

  2. ML-学习提纲2

    https://machinelearningmastery.com/a-tour-of-machine-learning-algorithms/ http://blog.csdn.net/u0110 ...

  3. [C4] Andrew Ng - Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization

    About this Course This course will teach you the "magic" of getting deep learning to work ...

  4. [RxJS] Implement the `map` Operator from Scratch in RxJS

    While it's great to use the RxJS built-in operators, it's also important to realize you now have the ...

  5. How to implement an algorithm from a scientific paper

    Author: Emmanuel Goossaert 翻译 This article is a short guide to implementing an algorithm from a scie ...

  6. A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python)

    A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python) MACHINE LEARNING PYTHON  ...

  7. Learning WCF Chapter1 Creating a New Service from Scratch

    You’re about to be introduced to the WCF service. This lab isn’t your typical “Hello World”—it’s “He ...

  8. Developing a Custom Membership Provider from the scratch, and using it in the FBA (Form Based Authentication) in SharePoint 2010

    //http://blog.sharedove.com/adisjugo/index.php/2011/01/05/writing-a-custom-membership-provider-and-u ...

  9. [Laravel] 14 - REST API: Laravel from scratch

    前言 一.基础 Ref: Build a REST API with Laravel API resources Goto: [Node.js] 08 - Web Server and REST AP ...

随机推荐

  1. maven项目转换为gradle项目

    进入到项目根目录,运行 gradle init --type pom 上面的命令会根据pom文件自动生成gradle项目所需的文件和配置,然后以gradle项目重新导入即可.

  2. 08 nginx+uWSGI+django+virtualenv+supervisor发布web服务器

    一.为什么要用nginx,uwsgi? 1 首先nginx 是对外的服务接口,外部浏览器通过url访问nginx, 2nginx 接收到浏览器发送过来的http请求,将包进行解析,分析url,如果是静 ...

  3. 原生JS+CSS实现日期插件

    笔者最近在学习Element UI,觉得它提供的日期选择器既简单又美观,于是仿照着写了一个日期插件.笔者使用到的技术有ES5、CSS和HTML,控件兼容IE10+和谷歌浏览器.有一点需要注意,笔者使用 ...

  4. Caffe常用算子GPU和CPU对比

    通过整理LeNet、AlexNet、VGG16、GoogLeNet、ResNet、MLP统计出的常用算子(不包括ReLU),表格是对比. Prelu Cpu版 Gpu版 for (int i = 0; ...

  5. Java高并发程序设计学习笔记(一):并行简介以及重要概念

    转自:https://blog.csdn.net/dataiyangu/article/details/86211544#_28 文章目录为什么需要并行?反对意见大势所趋几个重要的概念同步(synch ...

  6. 用python编写一个合格的ftp程序,思路是怎样的?

      经验1.一般在比较正规的类中的构造函数.都会有一个verify_args函数,用于验证传入参数.尤其是对于系统传参.2.并且系统传参,其实后面大概都是一个函数名 例如:python server. ...

  7. shell中字符串操作【转】

    转自:http://blog.chinaunix.net/uid-29091195-id-3974751.html 我们所遇到的编程语言中(汇编除外)都少不了字符串处理函数吧,当然shell编程也不例 ...

  8. 牛客练习赛47 E DongDong数颜色 (树上启发式合并)

    链接:https://ac.nowcoder.com/acm/contest/904/E 来源:牛客网 DongDong数颜色 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 5242 ...

  9. arm开发板make编译时遇到 make[2]:*** [s-attrtab] 已杀死 问题的解决方案

    未验证 出现“make[2]: *** [s-attrtab] 已杀死”log 是由于内存不足 解决方案 增加swapfile 步骤如下: 1. 查看当前swapfile状态 root@ubuntu: ...

  10. The Preliminary Contest for ICPC Asia Nanchang 2019 B. Fire-Fighting Hero

    题目:https://nanti.jisuanke.com/t/41349 思路:dijkstra最短路径 先以 fire-fighting hero为起点 跑一遍dijkstra 建立 起点 p 并 ...