ACGAN (Auxiliary Classifier GAN) training script in PyTorch, trained on MNIST:
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

cuda = True if torch.cuda.is_available() else False


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        # Condition the noise on the class by multiplying it with the label embedding
        gen_input = torch.mul(self.label_emb(labels), noise)
        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4

        # Output layers: a real/fake head and an auxiliary class-prediction head
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax(dim=1))

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label


# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor


def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator
        validity, pred_label = discriminator(gen_imgs)
        g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Calculate discriminator accuracy
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)
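
The script above only writes sample grids at fixed intervals and never saves the trained weights. If you want to reuse the generator afterwards, e.g. to produce a batch of a single chosen digit class, a minimal sketch follows; it assumes you append a torch.save(generator.state_dict(), "generator.pth") call after the training loop and reuses the Generator class and opt from the script, while the checkpoint path and the sample_class helper are illustrative names, not part of the original code:

# Minimal class-conditional sampling sketch (assumes the generator's
# state_dict was saved to "generator.pth" after training).
generator = Generator()
generator.load_state_dict(torch.load("generator.pth", map_location="cpu"))
generator.eval()


def sample_class(digit, n_samples=16):
    """Generate n_samples images of a single digit class (0-9)."""
    z = torch.tensor(np.random.normal(0, 1, (n_samples, opt.latent_dim)), dtype=torch.float32)
    labels = torch.full((n_samples,), digit, dtype=torch.long)
    with torch.no_grad():
        imgs = generator(z, labels)
    save_image(imgs, "images/class_%d.png" % digit, nrow=4, normalize=True)


sample_class(7)

Training hyperparameters such as --n_epochs, --batch_size, --lr, and --latent_dim can be changed on the command line through the argparse flags defined at the top of the script.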
