【源码解读】cycleGAN(二) :训练
源码地址:https://github.com/aitorzip/PyTorch-CycleGAN
训练的代码见于train.py,首先定义好网络,两个生成器A2B, B2A和两个判别器A, B,以及对应的优化器(优化器的设置保证了只更新生成器或判别器,不会互相影响)
###### Definition of variables ######
# Networks
netG_A2B = Generator(opt.input_nc, opt.output_nc)
netG_B2A = Generator(opt.output_nc, opt.input_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)
# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))
然后是数据
# Dataset loader
transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC),
transforms.RandomCrop(opt.size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),
batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)
接着就可以求取损失,反传梯度,更新网络,更新网络的时候首先更新生成器,然后分别更新两个判别器
生成器:损失函数=身份损失+对抗损失+循环一致损失
###### Generators A2B and B2A ######
optimizer_G.zero_grad() # Identity loss
# G_A2B(B) should equal B if real B is fed
same_B = netG_A2B(real_B)
loss_identity_B = criterion_identity(same_B, real_B)*5.0
# G_B2A(A) should equal A if real A is fed
same_A = netG_B2A(real_A)
loss_identity_A = criterion_identity(same_A, real_A)*5.0 # GAN loss
fake_B = netG_A2B(real_A)
pred_fake = netD_B(fake_B)
loss_GAN_A2B = criterion_GAN(pred_fake, target_real) fake_A = netG_B2A(real_B)
pred_fake = netD_A(fake_A)
loss_GAN_B2A = criterion_GAN(pred_fake, target_real) # Cycle loss
recovered_A = netG_B2A(fake_B)
loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0 recovered_B = netG_A2B(fake_A)
loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0 # Total loss
loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
loss_G.backward() optimizer_G.step()
判别器A 损失函数= 真实样本分类损失 + 虚假样本分类损失
###### Discriminator A ######
optimizer_D_A.zero_grad() # Real loss
pred_real = netD_A(real_A)
loss_D_real = criterion_GAN(pred_real, target_real) # Fake loss
fake_A = fake_A_buffer.push_and_pop(fake_A)
pred_fake = netD_A(fake_A.detach())
loss_D_fake = criterion_GAN(pred_fake, target_fake) # Total loss
loss_D_A = (loss_D_real + loss_D_fake)*0.5
loss_D_A.backward() optimizer_D_A.step()
###################################
判别器B 损失函数= 真实样本分类损失 + 虚假样本分类损失
###### Discriminator B ######
optimizer_D_B.zero_grad() # Real loss
pred_real = netD_B(real_B)
loss_D_real = criterion_GAN(pred_real, target_real) # Fake loss
fake_B = fake_B_buffer.push_and_pop(fake_B)
pred_fake = netD_B(fake_B.detach())
loss_D_fake = criterion_GAN(pred_fake, target_fake) # Total loss
loss_D_B = (loss_D_real + loss_D_fake)*0.5
loss_D_B.backward() optimizer_D_B.step()
###################################
可以注意到,判别器损失中,虚假样本fake_A,fake_B都采用detach()操作,脱离计算图,这样判别器的损失进行反向传播不会对整个网络计算梯度,避免了不必要的计算
【源码解读】cycleGAN(二) :训练的更多相关文章
- YYModel 源码解读(二)之NSObject+YYModel.h (1)
本篇文章主要介绍 _YYModelPropertyMeta 前边的内容 首先先解释一下前边的辅助函数和枚举变量,在写一个功能的时候,这些辅助的东西可能不是一开始就能想出来的,应该是在后续的编码过程中 ...
- redux源码解读(二)
之前,已经写过一篇redux源码解读(一),主要分析了 redux 的核心思想,并用100多行代码实现一个简单的 redux .但是,那个实现还不具备合并 reducer 和添加 middleware ...
- swoft| 源码解读系列二: 启动阶段, swoft 都干了些啥?
date: 2018-8-01 14:22:17title: swoft| 源码解读系列二: 启动阶段, swoft 都干了些啥?description: 阅读 sowft 框架源码, 了解 sowf ...
- Spark学习之路 (十六)SparkCore的源码解读(二)spark-submit提交脚本
一.概述 上一篇主要是介绍了spark启动的一些脚本,这篇主要分析一下Spark源码中提交任务脚本的处理逻辑,从spark-submit一步步深入进去看看任务提交的整体流程,首先看一下整体的流程概要图 ...
- 【原】SparkContex源码解读(二)
版权声明:本文为原创文章,未经允许不得转载. 继续前一篇的内容.前一篇内容为: SparkContex源码解读(一)http://www.cnblogs.com/yourarebest/p/53266 ...
- Alamofire源码解读系列(二)之错误处理(AFError)
本篇主要讲解Alamofire中错误的处理机制 前言 在开发中,往往最容易被忽略的内容就是对错误的处理.有经验的开发者,能够对自己写的每行代码负责,而且非常清楚自己写的代码在什么时候会出现异常,这样就 ...
- ReactiveCocoa源码解读(二)
上一篇解读了ReactiveCocoa的三个重要的类的底层实现,本篇继续. 一.RACMulticastConnection 1.应用 RACMulticastConnection: 用于当一个信号被 ...
- YYModel 源码解读(二)之YYClassInfo.h (3)
前边3篇介绍了YYClassinfo 文件的组成单元,算是功能的分割,按照业务的设计思想来说,方向应该是相反的 由此引申出我们在设计api的思想其实和项目管理是很类似的----- 一些题外话 1.目的 ...
- PhotoSwipe源码解读系列(二)
作者: 铁锚 日期: 2013年12月19日 说明: 本系列文章为草稿,等待后期完善.源码是jQuery版本的,code.photoswipe-3.0.5.js 1. 代码开头,就是一些版权申明,没什 ...
- Netty源码解读(二)-服务端源码讲解
简单Echo案例 注释版代码地址:netty 代码是netty的源码,我添加了自己理解的中文注释. 了解了Netty的线程模型和组件之后,我们先看看如何写一个简单的Echo案例,后续的源码讲解都基于此 ...
随机推荐
- 可持久化Trie模板
如果你了解过 01 Trie 和 可持久化线段树(例如 : 主席树 ).那么就比较好去可持久化 Trie 可持久化 Trie 当 01 Trie 用的时候能很方便解决一些原本 01 Trie 不能解决 ...
- 视图:setContentView()
1.setContentView的作用是将View加载到根view之上,这样当显示view时,先显示根view,然后在显示子view,以此类推,最终将所有view显示出来. 2.setContentV ...
- Windows和Linux下搭建J2sdk的环境
J2SDK 作为jsp系统配置中必不可少的组件,越来越多的得到应用.下来是我整理的以往工作时搜集的资料.使用时方便查询,希望对广大的工程师有帮助. windows服务器环境下 j2sdk 的安装和环境 ...
- (67)c++后台开发
还记得自己在学校的时候,一直都比较注重的是:编程语言+数据结构与算法.没错,对于一个在校的计算机专业的学生,这是很重要的方面.但是,这往往不够,或许是因为毕业前一直没有进入企业实习,以至于自己在毕业之 ...
- PTA编程总结三
7-1 抓老鼠啊~亏了还是赚了? (20 分) 某地老鼠成灾,现悬赏抓老鼠,每抓到一只奖励10元,于是开始跟老鼠斗智斗勇:每天在墙角可选择以下三个操作:放置一个带有一块奶酪的捕鼠夹(T),或者放置一块 ...
- ppapi,npapi
PPAPI也就是Pepper Plugin API,是在原有网景NPAPI(Netscape Plugin API)基础上发展而来的.NPAPI是当今最流行的插件架构,几乎所有浏览器都支持,不过存在很 ...
- 1、Shiro简介以及整体架构
1.Shiro概念和作用: 利用shiro可以快速完成权限管理模块的开发 Spring的官网也是用Shiro做安全管理的... Shiro整体架构: 可能你感觉上面的图片很乱,但是你一定要先大体有个印 ...
- css随笔记(持续更新)
/*DIV鼠标穿透*/ div{pointer-events:none;} /*清除IE11默认×*/ input::-ms-clear{display:none;} 使用伪类写边框部分三角 右上角三 ...
- Mysql数据库存储数据时间与系统获取时间不一致
最近进行项目开发,发现存在数据库内的数据和系统查询到的数据相差8小时 发现有2种比较合适的方法 (一)修改mysql服务的区时 centos进入mysql查看/修改时区 1.输入以下命令进入mysql ...
- 【ABAP系列】SAP ABAP 行列转换的方法
公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[ABAP系列]SAP ABAP 行列转换的方法 ...