点击查看代码
import argparse
import os import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image os.makedirs("images", exist_ok=True) parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt) cuda = True if torch.cuda.is_available() else False def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0) class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__() self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim) self.init_size = opt.img_size // 4 # Initial size before upsampling
self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2)) self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(128),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 128, 3, stride=1, padding=1),
nn.BatchNorm2d(128, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 64, 3, stride=1, padding=1),
nn.BatchNorm2d(64, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
nn.Tanh(),
) def forward(self, noise, labels):
gen_input = torch.mul(self.label_emb(labels), noise)
out = self.l1(gen_input)
out = out.view(out.shape[0], 128, self.init_size, self.init_size)
img = self.conv_blocks(out)
return img class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__() def discriminator_block(in_filters, out_filters, bn=True):
"""Returns layers of each discriminator block"""
block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
if bn:
block.append(nn.BatchNorm2d(out_filters, 0.8))
return block self.conv_blocks = nn.Sequential(
*discriminator_block(opt.channels, 16, bn=False),
*discriminator_block(16, 32),
*discriminator_block(32, 64),
*discriminator_block(64, 128),
) # The height and width of downsampled image
ds_size = opt.img_size // 2 ** 4 # Output layers
self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax()) def forward(self, img):
out = self.conv_blocks(img)
out = out.view(out.shape[0], -1)
validity = self.adv_layer(out)
label = self.aux_layer(out) return validity, label # Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss() # Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator() if cuda:
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
auxiliary_loss.cuda() # Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal) # Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
) # Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor def sample_image(n_row, batches_done):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# Sample noise
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
# Get labels ranging from 0 to n_classes for n rows
labels = np.array([num for _ in range(n_row) for num in range(n_row)])
labels = Variable(LongTensor(labels))
gen_imgs = generator(z, labels)
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True) for epoch in range(opt.n_epochs):
for i, (imgs, labels) in enumerate(dataloader): batch_size = imgs.shape[0] # Adversarial ground truths
valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False) # Configure input
real_imgs = Variable(imgs.type(FloatTensor))
labels = Variable(labels.type(LongTensor)) # -----------------
# Train Generator
# ----------------- optimizer_G.zero_grad() # Sample noise and labels as generator input
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size))) # Generate a batch of images
gen_imgs = generator(z, gen_labels) # Loss measures generator's ability to fool the discriminator
validity, pred_label = discriminator(gen_imgs)
g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels)) g_loss.backward()
optimizer_G.step() # ---------------------
# Train Discriminator
# --------------------- optimizer_D.zero_grad() # Loss for real images
real_pred, real_aux = discriminator(real_imgs)
d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2 # Loss for fake images
fake_pred, fake_aux = discriminator(gen_imgs.detach())
d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2 # Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2 # Calculate discriminator accuracy
pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
d_acc = np.mean(np.argmax(pred, axis=1) == gt) d_loss.backward()
optimizer_D.step() print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
)
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
sample_image(n_row=10, batches_done=batches_done)

ACGAN-pytorch的更多相关文章

  1. Ubutnu16.04安装pytorch

    1.下载Anaconda3 首先需要去Anaconda官网下载最新版本Anaconda3(https://www.continuum.io/downloads),我下载是是带有python3.6的An ...

  2. 解决运行pytorch程序多线程问题

    当我使用pycharm运行  (https://github.com/Joyce94/cnn-text-classification-pytorch )  pytorch程序的时候,在Linux服务器 ...

  3. 基于pytorch实现word2vec

    一.介绍 word2vec是Google于2013年推出的开源的获取词向量word2vec的工具包.它包括了一组用于word embedding的模型,这些模型通常都是用浅层(两层)神经网络训练词向量 ...

  4. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  5. pytorch实现VAE

    一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...

  6. PyTorch教程之Training a classifier

    我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...

  7. PyTorch教程之Neural Networks

    我们可以通过torch.nn package构建神经网络. 现在我们已经了解了autograd,nn基于autograd来定义模型并对他们有所区分. 一个 nn.Module模块由如下部分构成:若干层 ...

  8. PyTorch教程之Autograd

    在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法. 它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的. 一.Varia ...

  9. Linux安装pytorch的具体过程以及其中出现问题的解决办法

    1.安装Anaconda 安装步骤参考了官网的说明:https://docs.anaconda.com/anaconda/install/linux.html 具体步骤如下: 首先,在官网下载地址 h ...

  10. Highway Networks Pytorch

    导读 本文讨论了深层神经网络训练困难的原因以及如何使用Highway Networks去解决深层神经网络训练的困难,并且在pytorch上实现了Highway Networks. 一 .Highway ...

随机推荐

  1. python进阶之路6之 for循环方法

    while循环补充说明 1.死循环 真正的死循环是一旦执行 CPU功耗会极速上升 直到系统采取紧急措施 尽量不要让CPU长时间不间断运算 2.嵌套及全局标志位 强调:一个break只能结束它所在的那一 ...

  2. Java学习笔记:2021年12月31日下午-2022年1月1日上午

    Java学习笔记:2021年12月31日下午-2022年1月1日上午 摘要:主要记录了计算机的电气构成,学习Linux系统的原因以及关于Linux以及相关操作的基础知识. 目录 Java学习笔记:20 ...

  3. 01-Verilog基础

    Verilog RTL编程实践 在进行数字IC设计过程中,RTL coding能力是非常重要的.结合逻辑仿真(VCS)和逻辑综合(Design Compiler)工具.看RTL. 1 ASIC Des ...

  4. 刷题笔记——2758.打印ASCII码 & 2759.打印字符

    题目 2758.打印ASCII码 2759.打印字符 代码 while True: try: a = input() print(ord(a)) except: break while True: t ...

  5. 移动 WEB 开发之 阿里百秀移动端页面制作

    一.技术选型 二.需求分析 1.页面布局分析 2. 屏幕划分 三.页面制作 1. 项目前期准备 搭建项目结构 创建 html 骨架结构以及引入相关样式 <head> <meta ch ...

  6. 论文翻译:2022_Phase-Aware Deep Speech Enhancement: It’s All About The Frame Length

    摘要 虽然相位感知语音处理近年来受到越来越多的关注,但大多数帧长约为32 ms的窄带STFT方法显示出相位对整体性能的影响相当温和.与此同时,现代基于深度神经网络(DNN)的方法,如Conv-TasN ...

  7. 使用pycharm打开sqlite的问题

    目录 问题:有同学在sqlite数据库文件执行数据库迁移完成前,点开了他,导致sqlite数据库被pycharm当成文本文件打开了,并且不会改了. 其实sqlite文件和电脑中的其他文件(xx.mp4 ...

  8. MySQL-知识点补充

    1.SQL注入问题 简单实现利用数据库实现注册登录功能: import pymysql conn = pymysql.connect( host='127.0.0.1', port=3306, use ...

  9. 郁金香-了解MFC信息机制

    控件的事件 窗口的信息

  10. js中的Object.keys、array.map、groupBy、call、apply总结分享

    分享几个js中的函数 Object.keys() 首先这个函数是用来干嘛的呢?是用来把一个json字符串里的key全都取出来重新整成一个数组的方法,那么这个函数怎么用呢,接下来贴出我最近碰见的用法: ...