手写DCGAN
//加上了注释,对pytorch又加深了理解
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from pylab import plt #pylab结合了pyplot和numpy class Config:
lr = 0.0002
nz = 100 #噪声维度
image_size = 64
image_size2 = 64
nc = 3 #图片是三通道的
ngf = 64 #G的特征层数
ndf = 64 #D的特征层数
beta1 = 0.5
batch_size = 32
max_epoch = 10
workers = 0
gpu = True opt = Config() #载入数据
transform = transforms.Compose([
transforms.Resize(opt.image_size),
transforms.ToTensor(),
transforms.Normalize([0.5]*3,[0.5]*3) #均值&标准差
]) dataset = CIFAR10(root='cifar10/',transform=transform,download=True)
dataloader = DataLoader(dataset,opt.batch_size,shuffle=True,num_workers=opt.workers) #输入的是噪声图片的维度
netg = nn.Sequential(
nn.ConvTranspose2d(opt.nz,opt.ngf*8,4,1,0,bias=False),
nn.BatchNorm2d(opt.ngf*8),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*8,opt.ngf*4,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf*4),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*4,opt.ngf*2,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf*2),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf*2,opt.ngf,4,2,1,bias=False),
nn.BatchNorm2d(opt.ngf),
nn.ReLU(True), nn.ConvTranspose2d(opt.ngf,opt.nc,4,2,1,bias=False),
nn.Tanh() #输出的是FAKE图片的维度
) netd = nn.Sequential(
nn.Conv2d(opt.nc,opt.ndf,4,2,1,bias=False),
nn.LeakyReLU(0.2,inplace=True), nn.Conv2d(opt.ndf, opt.ndf*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf*2),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*2, opt.ndf*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 4),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*4, opt.ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 8),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(opt.ndf*8,1,4,1,0,bias=False),
nn.Sigmoid()
) #optimizer
optimizerD = Adam(netd.parameters(),lr = opt.lr,betas=(opt.beta1,0.999))
optimizerG = Adam(netg.parameters(),lr = opt.lr,betas=(opt.beta1,0.999)) #criterion
criterion = nn.BCELoss() fix_noise = Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1).normal_(0,1))#高斯分布N(0,1)
if opt.gpu:
fix_noise = fix_noise.cuda()
netd.cuda()
netg.cuda()
criterion.cuda() print("开始训练") for epoch in range(opt.max_epoch):
for ii,data in enumerate(dataloader,start=0):
real,_ = data
input = Variable(real)
label = Variable(t.ones(input.size(0)))#一开始训练DIS用real image 所以给的label都是1,所以这个label大小和batch_size大小一样
noise = t.randn(input.size(0),opt.nz,1,1)#不是很理解后面两个1是干啥用的
noise = Variable(noise) if opt.gpu:
noise = noise.cuda()
input = input.cuda()
label = label.cuda() #____train disc____
netd.zero_grad()
#用real image train
output = netd(input)
loss_real = criterion(output.squeeze(),label)#output 与 1之间的loss
loss_real.backward()
# D_x = output.data.mean()#这是平均loss
#用fake image train
fake_pic = netg(noise).detach()#截断反向传播,只影响G不影响D
output2 = netd(fake_pic)
label.data.fill_(0) #把label的1改成0,因为是fake image
loss_fake = criterion(output2.squeeze(),label)
loss_fake.backward()
# D_x2 = output2.data.mean()
error_D = loss_real+loss_fake
optimizerD.step() #_____train generator__
netg.zero_grad()
label.data.fill_(1) #要计算的是生存的图片与真实的loss,所以是1
noise.data.normal_(0,1)#产生0-1的高斯噪声
fake_pic = netg(noise)
output = netd(fake_pic)
loss_G = criterion(output.squeeze(),label)
loss_G.backward()
optimizerG.step()
# D_G_z2 = output.data.mean() if epoch%2 == 0:
fake_u = netg(fix_noise)
imgs = make_grid(fake_u.data*0.5+0.5).cpu() #chw
plt.imshow(imgs.permute(1,2,0).numpy())
plt.show()
手写DCGAN的更多相关文章
- 卷积生成对抗网络(DCGAN)---生成手写数字
深度卷积生成对抗网络(DCGAN) ---- 生成 MNIST 手写图片 1.基本原理 生成对抗网络(GAN)由2个重要的部分构成: 生成器(Generator):通过机器生成数据(大部分情况下是图像 ...
- 【Win 10 应用开发】手写识别
记得前面(忘了是哪天写的,反正是前些天,请用力点击这里观看)老周讲了一个14393新增的控件,可以很轻松地结合InkCanvas来完成涂鸦.其实,InkCanvas除了涂鸦外,另一个大用途是墨迹识别, ...
- JS / Egret 单笔手写识别、手势识别
UnistrokeRecognizer 单笔手写识别.手势识别 UnistrokeRecognizer : https://github.com/RichLiu1023/UnistrokeRecogn ...
- 如何用卷积神经网络CNN识别手写数字集?
前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...
- 【转】机器学习教程 十四-利用tensorflow做手写数字识别
模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...
- caffe_手写数字识别Lenet模型理解
这两天看了Lenet的模型理解,很简单的手写数字CNN网络,90年代美国用它来识别钞票,准确率还是很高的,所以它也是一个很经典的模型.而且学习这个模型也有助于我们理解更大的网络比如Imagenet等等 ...
- 使用神经网络来识别手写数字【译】(三)- 用Python代码实现
实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...
- 手写原生ajax
关于手写原生ajax重要不重要,各位道友自己揣摩吧, 本着学习才能进步,分享大家共同受益,自己也在自己博客里写一下 function createXMLHTTPRequest() { //1.创建XM ...
- springmvc 动态代理 JDK实现与模拟JDK纯手写实现。
首先明白 动态代理和静态代理的区别: 静态代理:①持有被代理类的引用 ② 代理类一开始就被加载到内存中了(非常重要) 动态代理:JDK中的动态代理中的代理类是动态生成的.并且生成的动态代理类为$Pr ...
随机推荐
- ubuntu16.04下安装opencv-nonfree
在写计算机视觉与导航技术的课程作业,是关于sift和surf特征的提取及匹配.因为opencv中都有直接的函数可以调用. 关于SIFT和SURF的特征在opencv的nonfree模块中,从字面意思就 ...
- JavaCSV之写CSV文件
与JavaCSV读CSV文件相对应,JavaCSV也可以用来写数据到CSV文件中. 1.准备工作 (1)第三方包库下载地址:https://sourceforge.net/projects/javac ...
- 探究 Oracle 高水位对数据库性能影响
在开始深入分析之前,让我们先来了解一下高水位线 HWM. 一. HWM 的基本原理 (概念) 在 Oracle 中,高水位线(High-warter mark, HWM)被用来形容数据块的使用位置,即 ...
- talib 中文文档(九):# Volatility Indicator Functions 波动率指标函数
Volatility Indicator Functions 波动率指标函数 ATR - Average True Range 函数名:ATR 名称:真实波动幅度均值 简介:真实波动幅度均值(ATR) ...
- python基础之函数式编程、匿名函数、内置函数
一 函数式编程 不修改外部状态. 模仿数学里得函数进行编程. 用函数编程写出得代码相当精简. 可读性比较差. 例子: y=2*x+1 x=1 def test(x): return 2*x+1 tes ...
- 锁、volatile、CAS 比较
一.锁的劣势 (1) 在JDK1.5之前都是使用synchronized关键字保证同步的,这种通过使用一致的锁定协议来协调对共享状态的访问,可以确保无论哪个线程持有守 护变量的锁,都采用独占的方式来访 ...
- 数据库触发器,禁止DDL操作
CREATE TRIGGER [Object_Change_Trigger_DDL] ON DATABASE FOR ALTER_TABLE,DROP_TABLE,CREATE_TABLE,CREAT ...
- tools-eclipse-004-UML图安装
git:https://github.com/takezoe/amateras-modeler 下载:http://sourceforge.jp/projects/amateras/downloads ...
- js-jquery-SweetAlert2【二】配置与方法
一.配置 Argument Default value Description title null 模态对话框的标题.它可以在参数对象的title参数中设置,也可以在swal()方法的第一个参数 ...
- ThinkPhp3.2.3 多项目 后台 APP接口设计 框架设计
↓↓↓项目文件组成部分↓↓↓ APP文件是后台,index.php是入口文件 Interface文件是接口,注意这里不要用api命名!可能会有问题!interface.php是入口文件 注:两个入口文 ...