源码地址:https://github.com/mrzhu-cool/pix2pix-pytorch

相比于朱俊彦的版本,这一版更加简单易读

训练的代码在train.py,开头依然是很多代码的共同三板斧,加载参数,加载数据,加载模型

命令行参数

  1. # Training settings
  2. parser = argparse.ArgumentParser(description='pix2pix-pytorch-implementation')
  3. parser.add_argument('--dataset', required=True, help='facades')
  4. parser.add_argument('--batch_size', type=int, default=1, help='training batch size')
  5. parser.add_argument('--test_batch_size', type=int, default=1, help='testing batch size')
  6. parser.add_argument('--direction', type=str, default='b2a', help='a2b or b2a')
  7. parser.add_argument('--input_nc', type=int, default=3, help='input image channels')
  8. parser.add_argument('--output_nc', type=int, default=3, help='output image channels')
  9. parser.add_argument('--ngf', type=int, default=64, help='generator filters in first conv layer')
  10. parser.add_argument('--ndf', type=int, default=64, help='discriminator filters in first conv layer')
  11. parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count')
  12. parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
  13. parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
  14. parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
  15. parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
  16. parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
  17. parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
  18. parser.add_argument('--cuda', action='store_true', help='use cuda?')
  19. parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
  20. parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
  21. parser.add_argument('--lamb', type=int, default=10, help='weight on L1 term in objective')
  22. opt = parser.parse_args()

数据

  1. print('===> Loading datasets')
  2. root_path = "dataset/"
  3. train_set = get_training_set(root_path + opt.dataset, opt.direction)
  4. test_set = get_test_set(root_path + opt.dataset, opt.direction)
  5. training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
  6. testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)

模型

  1. print('===> Building models')
  2. net_g = define_G(opt.input_nc, opt.output_nc, opt.ngf, 'batch', False, 'normal', 0.02, gpu_id=device)
  3. net_d = define_D(opt.input_nc + opt.output_nc, opt.ndf, 'basic', gpu_id=device)

优化器,损失函数

  1. criterionGAN = GANLoss().to(device)
  2. criterionL1 = nn.L1Loss().to(device)
  3. criterionMSE = nn.MSELoss().to(device)
  4.  
  5. # setup optimizer
  6. optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
  7. optimizer_d = optim.Adam(net_d.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
  8. net_g_scheduler = get_scheduler(optimizer_g, opt)
  9. net_d_scheduler = get_scheduler(optimizer_d, opt)

接着按批次读取数据,首先更新判别器,判别器的输入是图像对(真,真)(真,假)

  1. ######################
  2. # (1) Update D network
  3. ######################
  4.  
  5. optimizer_d.zero_grad()
  6.  
  7. # train with fake
  8. fake_ab = torch.cat((real_a, fake_b), 1)
  9. pred_fake = net_d.forward(fake_ab.detach())
  10. loss_d_fake = criterionGAN(pred_fake, False)
  11.  
  12. # train with real
  13. real_ab = torch.cat((real_a, real_b), 1)
  14. pred_real = net_d.forward(real_ab)
  15. loss_d_real = criterionGAN(pred_real, True)
  16.  
  17. # Combined D loss
  18. loss_d = (loss_d_fake + loss_d_real) * 0.5
  19.  
  20. loss_d.backward()
  21.  
  22. optimizer_d.step()

然后更新生成器,生成器的损失由判别器产生的损失函数和真假图像之间的L1约束组成

  1. ######################
  2. # (2) Update G network
  3. ######################
  4.  
  5. optimizer_g.zero_grad()
  6.  
  7. # First, G(A) should fake the discriminator
  8. fake_ab = torch.cat((real_a, fake_b), 1)
  9. pred_fake = net_d.forward(fake_ab)
  10. loss_g_gan = criterionGAN(pred_fake, True)
  11.  
  12. # Second, G(A) = B
  13. loss_g_l1 = criterionL1(fake_b, real_b) * opt.lamb
  14.  
  15. loss_g = loss_g_gan + loss_g_l1
  16.  
  17. loss_g.backward()
  18.  
  19. optimizer_g.step()

最后更新学习率

  1. update_learning_rate(net_g_scheduler, optimizer_g)
  2. update_learning_rate(net_d_scheduler, optimizer_d)

比较核心的代码是网络构造,以及一些工具函数,放在后面写

【源码解读】pix2pix(一):训练的更多相关文章

  1. Bert系列(二)——源码解读之模型主体

    本篇文章主要是解读模型主体代码modeling.py.在阅读这篇文章之前希望读者们对bert的相关理论有一定的了解,尤其是transformer的结构原理,网上的资料很多,本文内容对原理部分就不做过多 ...

  2. Bert系列(三)——源码解读之Pre-train

    https://www.jianshu.com/p/22e462f01d8c pre-train是迁移学习的基础,虽然Google已经发布了各种预训练好的模型,而且因为资源消耗巨大,自己再预训练也不现 ...

  3. [源码分析] Facebook如何训练超大模型 --- (2)

    [源码分析] Facebook如何训练超大模型 --- (2) 目录 [源码分析] Facebook如何训练超大模型 --- (2) 0x00 摘要 0x01 回顾 1.1 ZeRO 1.1.1 Ze ...

  4. SDWebImage源码解读之SDWebImageDownloaderOperation

    第七篇 前言 本篇文章主要讲解下载操作的相关知识,SDWebImageDownloaderOperation的主要任务是把一张图片从服务器下载到内存中.下载数据并不难,如何对下载这一系列的任务进行设计 ...

  5. SDWebImage源码解读 之 NSData+ImageContentType

    第一篇 前言 从今天开始,我将开启一段源码解读的旅途了.在这里先暂时不透露具体解读的源码到底是哪些?因为也可能随着解读的进行会更改计划.但能够肯定的是,这一系列之中肯定会有Swift版本的代码. 说说 ...

  6. SDWebImage源码解读 之 UIImage+GIF

    第二篇 前言 本篇是和GIF相关的一个UIImage的分类.主要提供了三个方法: + (UIImage *)sd_animatedGIFNamed:(NSString *)name ----- 根据名 ...

  7. SDWebImage源码解读 之 SDWebImageCompat

    第三篇 前言 本篇主要解读SDWebImage的配置文件.正如compat的定义,该配置文件主要是兼容Apple的其他设备.也许我们真实的开发平台只有一个,但考虑各个平台的兼容性,对于框架有着很重要的 ...

  8. SDWebImage源码解读_之SDWebImageDecoder

    第四篇 前言 首先,我们要弄明白一个问题? 为什么要对UIImage进行解码呢?难道不能直接使用吗? 其实不解码也是可以使用的,假如说我们通过imageNamed:来加载image,系统默认会在主线程 ...

  9. SDWebImage源码解读之SDWebImageCache(上)

    第五篇 前言 本篇主要讲解图片缓存类的知识,虽然只涉及了图片方面的缓存的设计,但思想同样适用于别的方面的设计.在架构上来说,缓存算是存储设计的一部分.我们把各种不同的存储内容按照功能进行切割后,图片缓 ...

  10. SDWebImage源码解读之SDWebImageCache(下)

    第六篇 前言 我们在SDWebImageCache(上)中了解了这个缓存类大概的功能是什么?那么接下来就要看看这些功能是如何实现的? 再次强调,不管是图片的缓存还是其他各种不同形式的缓存,在原理上都极 ...

随机推荐

  1. FJOI2017 day2游记

    day0 早上复习了一下凸包,lct的板子,发现现在的我好菜鸡啊,做题基本上还得看题解,自己不是很能分析出来. 下午去看考场,在附中机房又写了一遍lct,然后a掉了文理分科完就回去了. 回家的路上走在 ...

  2. 微信长按识别二维码,在 vue 项目中的实现

    微信长按识别二维码是 QQ 浏览器的内置功能,该功能的基础一定要使用 img 标签引入图片,其他方式的二维码无法识别. 在 vue 中使用 QrcodeVue 插件 demo1 在 template ...

  3. Qt 静态库与共享库(动态库)共享配置的一个小办法

    对于用 QtCreator 编写静态库,动态库,如何能够以最小的改动, 方便的实现两种形式的库文件生成:可以这麽做: 1)使用想到建立静态库 2)在项目配置文件*.pro  中: TARGET = n ...

  4. 1-window搭建git

    windows7搭建Git私服 作为版本控制工具大多公司会选用Git,但svn也具有一定的优势,在对开源项目管理方面,Git具有一定的优势,我们可以将自己的项目放到GitHub上面,供大家交流学习,但 ...

  5. Tomcat配置多域名 Alias

    在Tomcat配置多域名,目的是和apache相对应,实现多域名访问. 使用 < Alias></ Alias>,务必注意,使用的是首字母大写. 我刚开配置使用小写,如果hos ...

  6. java中的char,short,int,long占几个字节

    1:“字节”是byte,“位”是bit : 2: 1 byte = 8 bit : char 在java中是2个字节.java采用unicode,2个字节(16位)来表示一个字符. short 2个字 ...

  7. docker部署Javaweb项目(jdk+tomcat+mysql)

    步骤一:在主机下载安装docker,参照Centos7上安装docker 步骤二:下载Linux版本的JDK1.6和Tomcat6.0(其他项目若依赖其他版本的运行环境可选择另外版本下载),通过sec ...

  8. 从 sourcemap 中获取源码

    使用 paazmaya/shuji: Reverse engineering JavaScript and CSS sources from sourcemaps 可以从 sourcemap 中获取源 ...

  9. 前端必须掌握的 nginx 技能(1)

    概述 作为一个前端,我觉得必须要学会使用 nginx 干下面几件事: 代理静态资源 设置反向代理(添加https) 设置缓存 设置 log 部署 smtp 服务 设置 redis 缓存(选) 下面我按 ...

  10. python学习笔记:(四)tuple(元组)常用方法

    tuple(元组)的常用方法 1.del 删除元组 #del 删除元组 a=(1,2,3) del a print(a) 2.len() 计算元组中,值的个数 #len:计算元组元素的个数 a=(1, ...