//加上了注释,对pytorch又加深了理解
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from pylab import plt #pylab结合了pyplot和numpy class Config:
lr = 0.0002
nz = 100 #噪声维度
image_size = 64
image_size2 = 64
nc = 3 #图片是三通道的
ngf = 64 #G的特征层数
ndf = 64 #D的特征层数
beta1 = 0.5
batch_size = 32
max_epoch = 10
workers = 0
gpu = True opt = Config() #载入数据
transform = transforms.Compose([
transforms.Resize(opt.image_size),
transforms.ToTensor(),
transforms.Normalize([0.5]*3,[0.5]*3) #均值&标准差
]) dataset = CIFAR10(root='cifar10/',transform=transform,download=True)
dataloader = DataLoader(dataset,opt.batch_size,shuffle=True,num_workers=opt.workers) #输入的是噪声图片的维度
netg = nn.Sequential(
nn.ConvTranspose2d(opt.nz,opt.ngf*8,4,1,0,bias=False),
nn.BatchNorm2d(opt.ngf*8),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*8,opt.ngf*4,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf*4),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*4,opt.ngf*2,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf*2),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*2,opt.ngf,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf,opt.nc,4,2,1,bias=False),
nn.Tanh() #输出的是FAKE图片的维度
) netd = nn.Sequential(
nn.Conv2d(opt.nc,opt.ndf,4,2,1,bias=False),
nn.LeakyReLU(0.2,inplace=True), nn.Conv2d(opt.ndf, opt.ndf*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf*2),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*2, opt.ndf*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 4),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*4, opt.ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 8),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*8,1,4,1,0,bias=False),
nn.Sigmoid()
) #optimizer
optimizerD = Adam(netd.parameters(),lr = opt.lr,betas=(opt.beta1,0.999))
optimizerG = Adam(netg.parameters(),lr = opt.lr,betas=(opt.beta1,0.999)) #criterion
criterion = nn.BCELoss() fix_noise = Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1).normal_(0,1))#高斯分布N(0,1)
if opt.gpu:
fix_noise = fix_noise.cuda()
netd.cuda()
netg.cuda()
criterion.cuda() print("开始训练") for epoch in range(opt.max_epoch):
for ii,data in enumerate(dataloader,start=0):
real,_ = data
input = Variable(real)
label = Variable(t.ones(input.size(0)))#一开始训练DIS用real image 所以给的label都是1,所以这个label大小和batch_size大小一样
noise = t.randn(input.size(0),opt.nz,1,1)#不是很理解后面两个1是干啥用的
noise = Variable(noise) if opt.gpu:
noise = noise.cuda()
input = input.cuda()
label = label.cuda() #____train disc____
netd.zero_grad()
#用real image train
output = netd(input)
loss_real = criterion(output.squeeze(),label)#output 与 1之间的loss
loss_real.backward()
# D_x = output.data.mean()#这是平均loss
#用fake image train
fake_pic = netg(noise).detach()#截断反向传播,只影响G不影响D
output2 = netd(fake_pic)
label.data.fill_(0) #把label的1改成0,因为是fake image
loss_fake = criterion(output2.squeeze(),label)
loss_fake.backward()
# D_x2 = output2.data.mean()
error_D = loss_real+loss_fake
optimizerD.step() #_____train generator__
netg.zero_grad()
label.data.fill_(1) #要计算的是生存的图片与真实的loss,所以是1
noise.data.normal_(0,1)#产生0-1的高斯噪声
fake_pic = netg(noise)
output = netd(fake_pic)
loss_G = criterion(output.squeeze(),label)
loss_G.backward()
optimizerG.step()
# D_G_z2 = output.data.mean() if epoch%2 == 0:
fake_u = netg(fix_noise)
imgs = make_grid(fake_u.data*0.5+0.5).cpu() #chw
plt.imshow(imgs.permute(1,2,0).numpy())
plt.show()

手写DCGAN的更多相关文章

  1. 卷积生成对抗网络(DCGAN)---生成手写数字

    深度卷积生成对抗网络(DCGAN) ---- 生成 MNIST 手写图片 1.基本原理 生成对抗网络(GAN)由2个重要的部分构成: 生成器(Generator):通过机器生成数据(大部分情况下是图像 ...

  2. 【Win 10 应用开发】手写识别

    记得前面(忘了是哪天写的,反正是前些天,请用力点击这里观看)老周讲了一个14393新增的控件,可以很轻松地结合InkCanvas来完成涂鸦.其实,InkCanvas除了涂鸦外,另一个大用途是墨迹识别, ...

  3. JS / Egret 单笔手写识别、手势识别

    UnistrokeRecognizer 单笔手写识别.手势识别 UnistrokeRecognizer : https://github.com/RichLiu1023/UnistrokeRecogn ...

  4. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

  5. 【转】机器学习教程 十四-利用tensorflow做手写数字识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  6. caffe_手写数字识别Lenet模型理解

    这两天看了Lenet的模型理解,很简单的手写数字CNN网络,90年代美国用它来识别钞票,准确率还是很高的,所以它也是一个很经典的模型.而且学习这个模型也有助于我们理解更大的网络比如Imagenet等等 ...

  7. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  8. 手写原生ajax

    关于手写原生ajax重要不重要,各位道友自己揣摩吧, 本着学习才能进步,分享大家共同受益,自己也在自己博客里写一下 function createXMLHTTPRequest() { //1.创建XM ...

  9. springmvc 动态代理 JDK实现与模拟JDK纯手写实现。

    首先明白 动态代理和静态代理的区别: 静态代理:①持有被代理类的引用  ② 代理类一开始就被加载到内存中了(非常重要) 动态代理:JDK中的动态代理中的代理类是动态生成的.并且生成的动态代理类为$Pr ...

随机推荐

  1. spriing boot 启动报错:Cannot determine embedded database driver class for database type NONE

    最近在学习使用spring boot.使用maven创建好工程,只引用需要用到的spring boot相关的jar包,除此之外没有任何的配置. 写了一个最简单的例子,如下所示: package com ...

  2. Linux下tomcat启动项目原因排查

    先停掉tomcat服务器: 然后把文件删除: 这时候启动服务器: 看下有没有启动成功: 接着把重新优化过的代码用X ftp传上去. 等几分钟就可以. 如果老是出现问题,就去catalina.out文件 ...

  3. PHP debug_backtrace() 函数打印调用处的调试信息

    http://php.net/manual/zh/function.debug-backtrace.php debug_backtrace (PHP 4 >= 4.3.0, PHP 5, PHP ...

  4. 使用jackson工具类把对象或集合转为JSON格式

    jackson使用方法: 1.加入jar包: jackson-annotations-2.2.2.jar jackson-core-2.2.2.jar jackson-databind-2.2.2.j ...

  5. R语言统计词频 画词云

    原始数据: 程序: #统计词频 library(wordcloud) # F:/master2017/ch4/weibo170.cut.txt text <- readLines("F ...

  6. VS2010/MFC编程入门之七(对话框:为对话框添加控件)

    创建对话框资源需要创建对话框模板.修改对话框属性.为对话框添加各种控件等步骤,前面一讲中鸡啄米已经讲了创建对话框模板和修改对话框属性,本节继续讲如何为对话框添加控件. 上一讲中鸡啄米创建了一个名为“A ...

  7. 大数据,why python

    大数据,why python ps, 2015-12-4 20:47:46 python" title="大数据,why python">http://www.op ...

  8. 分布式session的管理

    在分布式架构或微服务架构下,必须保证一个应用服务器上保存Session后,其它应用服务器可以同步或共享这个Session,可能会出现在A1系统登录后创建并保存Session,再次发起请求,请求被转发到 ...

  9. webpack踩过的坑(总结)

    使用process.argv 获取命令行使用的参数 // 判断是否带production参数,production会压缩js var isprod = false; for (var i in pro ...

  10. 【知识总结】Activiti工作流学习入门

    1. 我理解的工作流: 在工作中慢慢接触的业务流程,就向流程控制语言一样,一步一步都对应的不同的业务,但整体串联起来就是一个完整的业务.而且实际工作中尤其是在企业内部系统的研发中,确实需要对应许多审批 ...