一、VAE的具体结构

二、VAE的pytorch实现

1加载并规范化MNIST

import相关类:

from __future__ import print_function
import argparse
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms

设置参数:

parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='enables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
print(args) #Sets the seed for generating random numbers. And returns a torch._C.Generator object.
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)

输出结果:

Namespace(batch_size=128, cuda=True, epochs=10, log_interval=10, no_cuda=False, seed=1)

下载数据集到./data/目录下:

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
trainset = datasets.MNIST('../data', train=True, download=True,transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(
trainset,
batch_size=args.batch_size, shuffle=True, **kwargs)
testset= datasets.MNIST('../data', train=False, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(
testset,
batch_size=args.batch_size, shuffle=True, **kwargs)
image, label = trainset[0]
print(len(trainset))
print(image.size())
image, label = testset[0]
print(len(testset))
print(image.size())

输出结果:

60000
torch.Size([1, 28, 28])
10000
torch.Size([1, 28, 28])

2定义VAE

首先我们介绍x.view方法:

x = torch.randn(4, 4)y = x.view(16)z = x.view(-1, 16)  # the size -1 is inferred from other dimensions
print(x)
print(y)
print(z)

输出结果:

 1.6154  1.1792  0.6450  1.2078
-0.4741 1.2145 0.8381 2.3532
0.2070 -0.9054 0.9262 0.6758
1.2613 0.5196 -1.7125 -0.0519
[torch.FloatTensor of size 4x4]
1.6154
1.1792
0.6450
1.2078
-0.4741
1.2145
0.8381
2.3532
0.2070
-0.9054
0.9262
0.6758
1.2613
0.5196
-1.7125
-0.0519
[torch.FloatTensor of size 16]
Columns 0 to 9
1.6154 1.1792 0.6450 1.2078 -0.4741 1.2145 0.8381 2.3532 0.2070 -0.9054 Columns 10 to 15
0.9262 0.6758 1.2613 0.5196 -1.7125 -0.0519
[torch.FloatTensor of size 1x16]

然后建立VAE模型

class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__() self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20)
self.fc22 = nn.Linear(400, 20)
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784) self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid() def encode(self, x):
h1 = self.relu(self.fc1(x))
return self.fc21(h1), self.fc22(h1) def reparametrize(self, mu, logvar):
std = logvar.mul(0.5).exp_()
eps = Variable(std.data.new(std.size()).normal_())
return eps.mul(std).add_(mu) def decode(self, z):
h3 = self.relu(self.fc3(z))
return self.sigmoid(self.fc4(h3)) def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparametrize(mu, logvar)
return self.decode(z), mu, logvar model = VAE()
if args.cuda:
model.cuda()

3.定义一个损失函数


reconstruction_function = nn.BCELoss()
reconstruction_function.size_average = False def loss_function(recon_x, x, mu, logvar):
BCE = reconstruction_function(recon_x, x.view(-1, 784)) # see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
KLD = torch.sum(KLD_element).mul_(-0.5) return BCE + KLD optimizer = optim.Adam(model.parameters(), lr=1e-3)

4.在训练数据上训练神经网络

我们只需要对数据迭代器进行循环,并将输入反馈到网络并进行优化。

for epoch in range(1, args.epochs + 1):
train(epoch)
test(epoch)

其中

def train(epoch):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = Variable(data)
if args.cuda:
data = data.cuda()
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.data[0]
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.data[0] / len(data))) print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset))) def test(epoch):
model.eval()
test_loss = 0
for data, _ in test_loader:
if args.cuda:
data = data.cuda()
data = Variable(data, volatile=True)
recon_batch, mu, logvar = model(data)
test_loss += loss_function(recon_batch, data, mu, logvar).data[0] test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))

Tips:

1.直接运行pytorch examples里的代码发现library not initialized at /pytorch/torch/lib/THC/THCGeneral.c错误

解决方案:sudo rm -r ~/.nv

2.该源码实现的论文为https://arxiv.org/pdf/1312.6114.pdf

pytorch实现VAE的更多相关文章

  1. Pytorch入门之VAE

    关于自编码器的原理见另一篇博客 : 编码器AE & VAE 这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现. 1. 稀疏编码 首先介绍一下“稀疏编码 ...

  2. Variational Auto-encoder(VAE)变分自编码器-Pytorch

    import os import torch import torch.nn as nn import torch.nn.functional as F import torchvision from ...

  3. pytorch实现DCGAN、pix2pix、DiscoGAN、CycleGAN、BEGAN以及VAE

    https://github.com/sunshineatnoon/Paper-Implementations

  4. Pytorch 细节记录

    1. PyTorch进行训练和测试时指定实例化的model模式为:train/eval eg: class VAE(nn.Module): def __init__(self): super(VAE, ...

  5. 【转载】 Pytorch 细节记录

    原文地址: https://www.cnblogs.com/king-lps/p/8570021.html ---------------------------------------------- ...

  6. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

  7. (转) The Incredible PyTorch

    转自:https://github.com/ritchieng/the-incredible-pytorch The Incredible PyTorch What is this? This is ...

  8. pytorch实现autoencoder

    关于autoencoder的内容简介可以参考这一篇博客,可以说写的是十分详细了https://sherlockliao.github.io/2017/06/24/vae/ 盗图一张,自动编码器讲述的是 ...

  9. 库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

    项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 P ...

随机推荐

  1. jenkins+ant+jmeter接口自动化测试(持续构建)

    使用badboy录制脚本,到处到jmeter后进行接口自动化,后来想着 可不可以用自动化来跑脚本呢,不用jmeter的图形界面呢, 选择了ant来进行构建,最后想到了用Jenkins来进行持续构建接口 ...

  2. FreeMarker简介

    什么是 FreeMarker? FreeMarker 是一款 模板引擎: 即一种基于模板和要改变的数据, 并用来生成输出文本(HTML网页,电子邮件,配置文件,源代码等)的通用工具. 它不是面向最终用 ...

  3. hdu 6194 沈阳网络赛--string string string(后缀数组)

    题目链接 Problem Description Uncle Mao is a wonderful ACMER. One day he met an easy problem, but Uncle M ...

  4. CVE-2016-10190 FFmpeg Http协议 heap buffer overflow漏洞分析及利用

    作者:栈长@蚂蚁金服巴斯光年安全实验室 -------- 1. 背景 FFmpeg是一个著名的处理音视频的开源项目,非常多的播放器.转码器以及视频网站都用到了FFmpeg作为内核或者是处理流媒体的工具 ...

  5. 使用sql语句复制一张表

    如何使用sql语句复制一张表? 方法一:第一步:先建一张新表,新表的结构与老表相等. create table newbiao like chengjibiao(老表名); 第二步:将老表中的值复制到 ...

  6. 【★】致全球第一批全帧3D游戏!

    图一 游戏片头 致逝去的青春记忆. 好久没人玩Ballance了吧,贴吧里貌似早已冷掉了. 作为一款经典游戏,Ballance的宣传却做得不到位,官方的介绍甚至没能展现出它的全部诱人之处.所以笔者决 ...

  7. dnsmasq一次成功的配置

    第一次用这个小软件,感觉还不错,因为没有像bind那样配置起来繁琐,并且我们也不需要去配置很多文件,内外网访问互不干涉. 我是在centos6.5下进行配置的: 先说说自己的理解: dnsmasq先去 ...

  8. 第二次项目冲刺(Beta阶段)第一天

    a. 安排连续七天的敏捷冲刺. 2017.5.18完成冲刺计划安排 2017.5.20完善主页面 1st day(目前位置) 2017.5.21完善功能 2st day 2017.5.22添加自定义重 ...

  9. Swing-JFileChooser的使用

    JFileChooser文件选择器是Swing中经常用到的一个控件.它的使用主要包含以下几个参数: 1.当前路径.也就是它第一次打开时所在的路径,许多软件喜欢设置为桌面. 2.文件过滤器.通过设置文件 ...

  10. 201521123067 《Java程序设计》第7周学习总结

    201521123067 <Java程序设计>第7周学习总结 1. 本周学习总结 以你喜欢的方式(思维导图或其他)归纳总结集合相关内容. 2. 书面作业 Q1.ArrayList代码分析 ...