//加上了注释,对pytorch又加深了理解
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from pylab import plt #pylab结合了pyplot和numpy class Config:
lr = 0.0002
nz = 100 #噪声维度
image_size = 64
image_size2 = 64
nc = 3 #图片是三通道的
ngf = 64 #G的特征层数
ndf = 64 #D的特征层数
beta1 = 0.5
batch_size = 32
max_epoch = 10
workers = 0
gpu = True opt = Config() #载入数据
transform = transforms.Compose([
transforms.Resize(opt.image_size),
transforms.ToTensor(),
transforms.Normalize([0.5]*3,[0.5]*3) #均值&标准差
]) dataset = CIFAR10(root='cifar10/',transform=transform,download=True)
dataloader = DataLoader(dataset,opt.batch_size,shuffle=True,num_workers=opt.workers) #输入的是噪声图片的维度
netg = nn.Sequential(
nn.ConvTranspose2d(opt.nz,opt.ngf*8,4,1,0,bias=False),
nn.BatchNorm2d(opt.ngf*8),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*8,opt.ngf*4,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf*4),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*4,opt.ngf*2,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf*2),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*2,opt.ngf,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf,opt.nc,4,2,1,bias=False),
nn.Tanh() #输出的是FAKE图片的维度
) netd = nn.Sequential(
nn.Conv2d(opt.nc,opt.ndf,4,2,1,bias=False),
nn.LeakyReLU(0.2,inplace=True), nn.Conv2d(opt.ndf, opt.ndf*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf*2),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*2, opt.ndf*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 4),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*4, opt.ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 8),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*8,1,4,1,0,bias=False),
nn.Sigmoid()
) #optimizer
optimizerD = Adam(netd.parameters(),lr = opt.lr,betas=(opt.beta1,0.999))
optimizerG = Adam(netg.parameters(),lr = opt.lr,betas=(opt.beta1,0.999)) #criterion
criterion = nn.BCELoss() fix_noise = Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1).normal_(0,1))#高斯分布N(0,1)
if opt.gpu:
fix_noise = fix_noise.cuda()
netd.cuda()
netg.cuda()
criterion.cuda() print("开始训练") for epoch in range(opt.max_epoch):
for ii,data in enumerate(dataloader,start=0):
real,_ = data
input = Variable(real)
label = Variable(t.ones(input.size(0)))#一开始训练DIS用real image 所以给的label都是1,所以这个label大小和batch_size大小一样
noise = t.randn(input.size(0),opt.nz,1,1)#不是很理解后面两个1是干啥用的
noise = Variable(noise) if opt.gpu:
noise = noise.cuda()
input = input.cuda()
label = label.cuda() #____train disc____
netd.zero_grad()
#用real image train
output = netd(input)
loss_real = criterion(output.squeeze(),label)#output 与 1之间的loss
loss_real.backward()
# D_x = output.data.mean()#这是平均loss
#用fake image train
fake_pic = netg(noise).detach()#截断反向传播,只影响G不影响D
output2 = netd(fake_pic)
label.data.fill_(0) #把label的1改成0,因为是fake image
loss_fake = criterion(output2.squeeze(),label)
loss_fake.backward()
# D_x2 = output2.data.mean()
error_D = loss_real+loss_fake
optimizerD.step() #_____train generator__
netg.zero_grad()
label.data.fill_(1) #要计算的是生存的图片与真实的loss,所以是1
noise.data.normal_(0,1)#产生0-1的高斯噪声
fake_pic = netg(noise)
output = netd(fake_pic)
loss_G = criterion(output.squeeze(),label)
loss_G.backward()
optimizerG.step()
# D_G_z2 = output.data.mean() if epoch%2 == 0:
fake_u = netg(fix_noise)
imgs = make_grid(fake_u.data*0.5+0.5).cpu() #chw
plt.imshow(imgs.permute(1,2,0).numpy())
plt.show()

手写DCGAN的更多相关文章

  1. 卷积生成对抗网络(DCGAN)---生成手写数字

    深度卷积生成对抗网络(DCGAN) ---- 生成 MNIST 手写图片 1.基本原理 生成对抗网络(GAN)由2个重要的部分构成: 生成器(Generator):通过机器生成数据(大部分情况下是图像 ...

  2. 【Win 10 应用开发】手写识别

    记得前面(忘了是哪天写的,反正是前些天,请用力点击这里观看)老周讲了一个14393新增的控件,可以很轻松地结合InkCanvas来完成涂鸦.其实,InkCanvas除了涂鸦外,另一个大用途是墨迹识别, ...

  3. JS / Egret 单笔手写识别、手势识别

    UnistrokeRecognizer 单笔手写识别.手势识别 UnistrokeRecognizer : https://github.com/RichLiu1023/UnistrokeRecogn ...

  4. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

  5. 【转】机器学习教程 十四-利用tensorflow做手写数字识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  6. caffe_手写数字识别Lenet模型理解

    这两天看了Lenet的模型理解,很简单的手写数字CNN网络,90年代美国用它来识别钞票,准确率还是很高的,所以它也是一个很经典的模型.而且学习这个模型也有助于我们理解更大的网络比如Imagenet等等 ...

  7. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  8. 手写原生ajax

    关于手写原生ajax重要不重要,各位道友自己揣摩吧, 本着学习才能进步,分享大家共同受益,自己也在自己博客里写一下 function createXMLHTTPRequest() { //1.创建XM ...

  9. springmvc 动态代理 JDK实现与模拟JDK纯手写实现。

    首先明白 动态代理和静态代理的区别: 静态代理:①持有被代理类的引用  ② 代理类一开始就被加载到内存中了(非常重要) 动态代理:JDK中的动态代理中的代理类是动态生成的.并且生成的动态代理类为$Pr ...

随机推荐

  1. Wireshark分析之TCP协议(二)

    (1)TCP首部格式 源端口:   用来传输数据报的端口 目标端口: 数据包将要发送到的端口 序号: 用来表示一个TCP片段.这个值用来表示数据流中的部分数据没有丢失 确认号:  表示通信中希望从另一 ...

  2. B - Network---UVA 315(无向图求割点)

        A Telephone Line Company (TLC) is establishing a new telephone cable network. They are connectin ...

  3. 深入理解Docker容器执行引擎runC

    1 简介 根据官方的定义:runC是一个根据OCI标准创建并运行容器的CLI tool. Docker就是基于runC创建的,简单地说,runC就是docker中最为核心的部分,容器的创建,运行,销毁 ...

  4. (2.2)DDL增强功能-自定义函数与存储过程

    1.存储过程 精华总结: 通过对比@@ERROR一般和if判断结合使用,@@TRANCOUNT和try catch块结合使用,xact_abort为on可以单独使用Xact_abort为off时,如果 ...

  5. ruby中的**

    在ruby中,**是乘方的意思.它是一个右结合性的运算.如下: 在多个乘方的时候,会先进行后面的乘方运算,结果作为指数再与前一位进行乘方运算.

  6. 1linux的基本命令

    查看命令的帮助信息man 命令名 文件操作touch 建立文件 (对于已存在文件,更新时间)cat 查看文件 (-n 自动加上行号)rm 删除文件cp 拷贝文件mv 移动/重命名文件more 分页查看 ...

  7. 478. Generate Random Point in a Circle

    1. 问题 给定一个圆的半径和圆心坐标,生成圆内点的坐标. 2. 思路 简单说 (1)在圆内随机取点不好做,但是如果画出这个圆的外接正方形,在正方形里面采样就好做了. (2)取两个random确定正方 ...

  8. PYPI 国内源

    搬砖自http://www.cnblogs.com/sunnydou/p/5801760.html 阿里云 http://mirrors.aliyun.com/pypi/simple/ 中国科技大学  ...

  9. vue 基础笔记

    Vue01笔记 ES6模块使用和新的函数声明方式 a) Import 一定不能放在函数内, 建议放在上方 b) Export 除了声明式的以外, 尽量放在代码的下方 Import {name,age} ...

  10. R之内存管理

    引言 R的内存管理机制究竟是什么样子的?最近几日在讲一个分享会,被同学问到这方面的问题,可是到网上去查,终于找到一篇R语言内存管理不过讲的不清不楚的,就拿memory.limit()函数来说,是在wi ...