//加上了注释,对pytorch又加深了理解
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from pylab import plt #pylab结合了pyplot和numpy class Config:
lr = 0.0002
nz = 100 #噪声维度
image_size = 64
image_size2 = 64
nc = 3 #图片是三通道的
ngf = 64 #G的特征层数
ndf = 64 #D的特征层数
beta1 = 0.5
batch_size = 32
max_epoch = 10
workers = 0
gpu = True opt = Config() #载入数据
transform = transforms.Compose([
transforms.Resize(opt.image_size),
transforms.ToTensor(),
transforms.Normalize([0.5]*3,[0.5]*3) #均值&标准差
]) dataset = CIFAR10(root='cifar10/',transform=transform,download=True)
dataloader = DataLoader(dataset,opt.batch_size,shuffle=True,num_workers=opt.workers) #输入的是噪声图片的维度
netg = nn.Sequential(
nn.ConvTranspose2d(opt.nz,opt.ngf*8,4,1,0,bias=False),
nn.BatchNorm2d(opt.ngf*8),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*8,opt.ngf*4,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf*4),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*4,opt.ngf*2,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf*2),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*2,opt.ngf,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf,opt.nc,4,2,1,bias=False),
nn.Tanh() #输出的是FAKE图片的维度
) netd = nn.Sequential(
nn.Conv2d(opt.nc,opt.ndf,4,2,1,bias=False),
nn.LeakyReLU(0.2,inplace=True), nn.Conv2d(opt.ndf, opt.ndf*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf*2),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*2, opt.ndf*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 4),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*4, opt.ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 8),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*8,1,4,1,0,bias=False),
nn.Sigmoid()
) #optimizer
optimizerD = Adam(netd.parameters(),lr = opt.lr,betas=(opt.beta1,0.999))
optimizerG = Adam(netg.parameters(),lr = opt.lr,betas=(opt.beta1,0.999)) #criterion
criterion = nn.BCELoss() fix_noise = Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1).normal_(0,1))#高斯分布N(0,1)
if opt.gpu:
fix_noise = fix_noise.cuda()
netd.cuda()
netg.cuda()
criterion.cuda() print("开始训练") for epoch in range(opt.max_epoch):
for ii,data in enumerate(dataloader,start=0):
real,_ = data
input = Variable(real)
label = Variable(t.ones(input.size(0)))#一开始训练DIS用real image 所以给的label都是1,所以这个label大小和batch_size大小一样
noise = t.randn(input.size(0),opt.nz,1,1)#不是很理解后面两个1是干啥用的
noise = Variable(noise) if opt.gpu:
noise = noise.cuda()
input = input.cuda()
label = label.cuda() #____train disc____
netd.zero_grad()
#用real image train
output = netd(input)
loss_real = criterion(output.squeeze(),label)#output 与 1之间的loss
loss_real.backward()
# D_x = output.data.mean()#这是平均loss
#用fake image train
fake_pic = netg(noise).detach()#截断反向传播,只影响G不影响D
output2 = netd(fake_pic)
label.data.fill_(0) #把label的1改成0,因为是fake image
loss_fake = criterion(output2.squeeze(),label)
loss_fake.backward()
# D_x2 = output2.data.mean()
error_D = loss_real+loss_fake
optimizerD.step() #_____train generator__
netg.zero_grad()
label.data.fill_(1) #要计算的是生存的图片与真实的loss,所以是1
noise.data.normal_(0,1)#产生0-1的高斯噪声
fake_pic = netg(noise)
output = netd(fake_pic)
loss_G = criterion(output.squeeze(),label)
loss_G.backward()
optimizerG.step()
# D_G_z2 = output.data.mean() if epoch%2 == 0:
fake_u = netg(fix_noise)
imgs = make_grid(fake_u.data*0.5+0.5).cpu() #chw
plt.imshow(imgs.permute(1,2,0).numpy())
plt.show()

手写DCGAN的更多相关文章

  1. 卷积生成对抗网络(DCGAN)---生成手写数字

    深度卷积生成对抗网络(DCGAN) ---- 生成 MNIST 手写图片 1.基本原理 生成对抗网络(GAN)由2个重要的部分构成: 生成器(Generator):通过机器生成数据(大部分情况下是图像 ...

  2. 【Win 10 应用开发】手写识别

    记得前面(忘了是哪天写的,反正是前些天,请用力点击这里观看)老周讲了一个14393新增的控件,可以很轻松地结合InkCanvas来完成涂鸦.其实,InkCanvas除了涂鸦外,另一个大用途是墨迹识别, ...

  3. JS / Egret 单笔手写识别、手势识别

    UnistrokeRecognizer 单笔手写识别.手势识别 UnistrokeRecognizer : https://github.com/RichLiu1023/UnistrokeRecogn ...

  4. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

  5. 【转】机器学习教程 十四-利用tensorflow做手写数字识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  6. caffe_手写数字识别Lenet模型理解

    这两天看了Lenet的模型理解,很简单的手写数字CNN网络,90年代美国用它来识别钞票,准确率还是很高的,所以它也是一个很经典的模型.而且学习这个模型也有助于我们理解更大的网络比如Imagenet等等 ...

  7. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  8. 手写原生ajax

    关于手写原生ajax重要不重要,各位道友自己揣摩吧, 本着学习才能进步,分享大家共同受益,自己也在自己博客里写一下 function createXMLHTTPRequest() { //1.创建XM ...

  9. springmvc 动态代理 JDK实现与模拟JDK纯手写实现。

    首先明白 动态代理和静态代理的区别: 静态代理:①持有被代理类的引用  ② 代理类一开始就被加载到内存中了(非常重要) 动态代理:JDK中的动态代理中的代理类是动态生成的.并且生成的动态代理类为$Pr ...

随机推荐

  1. 2017ACM-ICPC沈阳区域赛

    I-Little Boxes[大数] hdu6225  http://acm.hdu.edu.cn/showproblem.php?pid=6225 题意: 就是给四个大数,输出和. 思路: java ...

  2. ubuntu打开txt乱码

    因为不支持中文 输入命令: iconv -f gbk -t utf8 filename.txt > filename.txt.utf8

  3. vue - nodejs

    一.知识 打开Nodejs英文网:https://nodejs.org/en/ 中文网:http://nodejs.cn/ 我们会发现这样一句话: 翻译成中文如下: Node.js 是一个基于 Chr ...

  4. Scala数组和集合

    一.scala数组 数组定义1: var arr = new Array[String](3) String:存储的元素类型 3:存储3个元素 添加元素: arr(1) = "hello&q ...

  5. 广通软件获“2016年度中国最具影响力IT运维管理软件提供商”殊荣

    12月16日,“科技原力觉醒引领创新巅峰”-- 2016创新影响力年会暨国家产业服务平台•2016年终评活动在北京裕龙国际酒店落下帷幕. 本活动在主管部门的指导参与下,总结本年度技术成果并籍此对未来科 ...

  6. JSON数组成员反序列化

    场景: 构想客户端能够传递如下格式JSON字符串到服务端: {"KeyValueSetList":[{"SN":"RQ1001"," ...

  7. centos 阶段复习 2015-4-6 dd命令 hosts.allow和hosts.deny 啊铭的myssh脚本 清空history命令历史 /dev/zero 零发生器 /dev/null 黑洞 /dev/random 生成随机数 第十一节课

    centos 阶段复习 2015-4-6  dd命令 hosts.allow和hosts.deny 啊铭的myssh脚本 清空history命令历史  /dev/zero 零发生器  /dev/nul ...

  8. EasyUI Draggable 可拖动

    通过 $.fn.draggable.defaults 重写默认的 defaults. 用法 通过标记创建可拖动(draggable)元素. <div id="dd" clas ...

  9. PAT 1102 Invert a Binary Tree[比较简单]

    1102 Invert a Binary Tree(25 分) The following is from Max Howell @twitter: Google: 90% of our engine ...

  10. svn命令行使用

    1.将文件checkout到本地目录    svn checkout path(path是服务器上的目录)    例如:svn checkout svn://192.168.1.1/pro/domai ...