DA部分

输入图片大小:

images.size: torch.Size([1, 3, 512, 1024])
labels.size: torch.Size([1, 512, 1024])

input_size = (w, h) # input_size : <class 'tuple'>: (1024, 512)

input_size_target = (w, h) # <class 'tuple'>: (1024, 512)

分割后的特征图大小:

feat_source: ([1, 2048, 65, 129])
pred_source: ([1, 19, 65, 129])

pred_source = interp(pred_source) 上采样后 pred_source 大小变成: ([1, 19, 512, 1024])


创建网络:
1 model = DeeplabMulti(num_classes=args.num_classes)
2 def DeeplabMulti(num_classes=21):
3 model = ResNetMulti(Bottleneck, [3, 4, 23, 2, 1], num_classes)
4 return model
包含注意力的分割网络:
1 class ResNetMulti(nn.Module):
2
3 def forward(self, x, D, domain): # 源域进来就正常打分, 目标域进来就先加权后打分
4 x = self.conv1(x)
5 x = self.bn1(x)
6 x = self.relu(x)
7 x = self.maxpool(x)
8 x1 = self.layer1(x)
9 x2 = self.layer2(x1)
10 x3 = self.layer3(x2)
11 x4 = self.layer4(x3) # ft或者fs
12 if domain == 'source': # source:x4.size: torch.Size([1, 2048, 65, 129]) out.size: torch.Size([1, 19, 65, 129])
13 x4_a4 = x4
14 # 目标域 注意力图加权
15 if domain == 'target': # target:x4.size: torch.Size([1, 2048, 65, 129]) out.size: torch.Size([1, 19, 65, 129])
16 a4 = D[0](x4) #a4 等同于论文中的D(ft) 注意力图
17 a4 = self.tanh(a4) # 防止早期训练时梯度爆炸,tanh激活层作为正则化层
18 a4 = torch.abs(a4) # 绝对值 a4 = |D(ft)|
19 a4_big = a4.expand(x4.size()) # 即a',为了匹配目标域的维度,实现注意力图和目标域按元素相乘
20 x4_a4 = a4_big*x4 + x4 # ft'=ft+ft*a'
21 x5 = self.layer5(x4_a4)
22 out = self.layer6(x5)
23 # print('D[0]',D[0])
24 # print('domain:', domain)
25 # print('x4.size:', x4.size()) # x4.size: torch.Size([1, 2048, 65, 129])
26 # print('out.size:', out.size()) # out.size: torch.Size([1, 19, 65, 129])
27 return x4, out
判别器:(FCDiscriminator输入通道2048,而OutspaceDiscriminator输入通道是19)
model_D = nn.ModuleList([FCDiscriminator(num_classes=num_class_list[i]).train().to(
device) if i < 1 else OutspaceDiscriminator(num_classes=num_class_list[i]).train().to(device) for i in
range(2)])
class FCDiscriminator(nn.Module):
def __init__(self, num_classes, ndf = 64):
# print('num_classes:', num_classes) num_classes: 2048
super(FCDiscriminator, self).__init__()
self.conv1 = nn.Conv2d(num_classes, num_classes//2, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(num_classes//2, num_classes//4, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(num_classes//4, num_classes//8, kernel_size=3, stride=1, padding=1)
self.classifier = nn.Conv2d(num_classes//8, 1, kernel_size=3, stride=1, padding=1)
self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
#self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear')
#self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.conv1(x)
x = self.leaky_relu(x)
x = self.conv2(x)
x = self.leaky_relu(x)
x = self.conv3(x)
x = self.leaky_relu(x)
x = self.classifier(x)
#x = self.up_sample(x)
#x = self.sigmoid(x)
return x
class OutspaceDiscriminator(nn.Module):
def __init__(self, num_classes, ndf = 64):
super(OutspaceDiscriminator, self).__init__()
self.conv1 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
self.classifier = nn.Conv2d(ndf*8, 1, kernel_size=4, stride=2, padding=1) # 变成通道数为1
self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
#self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear')
#self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.conv1(x)
x = self.leaky_relu(x)
x = self.conv2(x)
x = self.leaky_relu(x)
x = self.conv3(x)
x = self.leaky_relu(x)
x = self.conv4(x)
x = self.leaky_relu(x)
x = self.classifier(x)
#x = self.up_sample(x)
#x = self.sigmoid(x)
return x
 

1 # D[0](x4):
2 # tensor([[[[0.0710, 0.1864, 0.2138, ..., 0.2505, 0.1997, 0.1675],
3 # [0.0946, 0.2139, 0.2130, ..., 0.2979, 0.2266, 0.1543],
4 # [0.1402, 0.2508, 0.2545, ..., 0.3649, 0.3104, 0.1574],
5 # ...,
6 # [0.1940, 0.3481, 0.3824, ..., 0.3082, 0.2303, 0.1237],
7 # [0.1855, 0.2981, 0.3047, ..., 0.2617, 0.1878, 0.0770],
8 # [0.0597, 0.1503, 0.1717, ..., 0.1718, 0.1432, 0.0634]]]],
9 # device='cuda:0', grad_fn= < AddBackward0 >)

1     # D[0]:
2 # FCDiscriminator(
3 # (conv1): Conv2d(2048, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
4 # (conv2): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
5 # (conv3): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
6 # (classifier): Conv2d(256, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
7 # (leaky_relu): LeakyReLU(negative_slope=0.2, inplace=True)
8 # )
1     # model_D[1]: OutspaceDiscriminator(
2 # (conv1): Conv2d(19, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
3 # (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
4 # (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
5 # (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
6 # (classifier): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
7 # (leaky_relu): LeakyReLU(negative_slope=0.2, inplace=True)
8 # )

#######开始训练######

train S

# train with source
feat_source, pred_source = model(images, model_D, 'source')
# ResNet返回两层输出结果(特征图),x4和out   model_D判别器模型,用来对resnet的x4层作输出,得到注意力图
# print('feat_source',feat_source) feat_source=x4 倒数第二层(输出通道数:2048) ; pred_source=out 最后一层(输出通道数:19(类别数)) 如果是目标域图片,那么,out代表加权后的特征图输出。
pred_source = interp(pred_source) # 源域特征图 参与分割损失计算
loss_seg = seg_loss(pred_source, labels)
loss_seg.backward()

# train with target

     feat_target, pred_target = model(images, model_D, 'target')
# print('feat_target.size, pred_target.size:', feat_target.size(), pred_target.size())
# feat_target.size, pred_target.size: torch.Size([1, 2048, 65, 129]) torch.Size([1, 19, 65, 129])
pred_target = interp_target(pred_target)
loss_adv = 0
D_out = model_D[0](feat_target) # 对倒数第二层的T域特征图打分
loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
D_out = model_D[1](F.softmax(pred_target, dim=1)) # 先把最后一层特征图变成概率图,再对概率图打分
# print('model_D[1]:', model_D[1])
loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
loss_adv = loss_adv * 0.01
loss_adv.backward()
optimizer.step()

train D


# train with source

     loss_D_source = 0
D_out_source = model_D[0](feat_source.detach())
loss_D_source += bce_loss(D_out_source,
torch.FloatTensor(D_out_source.data.size()).fill_(source_label).to(device))
D_out_source = model_D[1](F.softmax(pred_source.detach(), dim=1))
loss_D_source += bce_loss(D_out_source,
torch.FloatTensor(D_out_source.data.size()).fill_(source_label).to(device))
loss_D_source.backward()

# train with target

        loss_D_target = 0
D_out_target = model_D[0](feat_target.detach())
loss_D_target += bce_loss(D_out_target,
torch.FloatTensor(D_out_target.data.size()).fill_(target_label).to(device))
D_out_target = model_D[1](F.softmax(pred_target.detach(), dim=1))
loss_D_target += bce_loss(D_out_target,
torch.FloatTensor(D_out_target.data.size()).fill_(target_label).to(device))
loss_D_target.backward()
optimizer_D.step()

ST部分

 


images.size: torch.Size([1, 3, 512, 1024])

DAST 代码分析的更多相关文章

  1. Android代码分析工具lint学习

    1 lint简介 1.1 概述 lint是随Android SDK自带的一个静态代码分析工具.它用来对Android工程的源文件进行检查,找出在正确性.安全.性能.可使用性.可访问性及国际化等方面可能 ...

  2. pmd静态代码分析

    在正式进入测试之前,进行一定的静态代码分析及code review对代码质量及系统提高是有帮助的,以上为数据证明 Pmd 它是一个基于静态规则集的Java源码分析器,它可以识别出潜在的如下问题:– 可 ...

  3. [Asp.net 5] DependencyInjection项目代码分析-目录

    微软DI文章系列如下所示: [Asp.net 5] DependencyInjection项目代码分析 [Asp.net 5] DependencyInjection项目代码分析2-Autofac [ ...

  4. [Asp.net 5] DependencyInjection项目代码分析4-微软的实现(5)(IEnumerable<>补充)

    Asp.net 5的依赖注入注入系列可以参考链接: [Asp.net 5] DependencyInjection项目代码分析-目录 我们在之前讲微软的实现时,对于OpenIEnumerableSer ...

  5. 完整全面的Java资源库(包括构建、操作、代码分析、编译器、数据库、社区等等)

    构建 这里搜集了用来构建应用程序的工具. Apache Maven:Maven使用声明进行构建并进行依赖管理,偏向于使用约定而不是配置进行构建.Maven优于Apache Ant.后者采用了一种过程化 ...

  6. STM32启动代码分析 IAR 比较好

    stm32启动代码分析 (2012-06-12 09:43:31) 转载▼     最近开始使用ST的stm32w108芯片(也是一款zigbee芯片).开始看他的启动代码看的晕晕呼呼呼的. 还好在c ...

  7. 常用 Java 静态代码分析工具的分析与比较

    常用 Java 静态代码分析工具的分析与比较 简介: 本文首先介绍了静态代码分析的基 本概念及主要技术,随后分别介绍了现有 4 种主流 Java 静态代码分析工具 (Checkstyle,FindBu ...

  8. SonarQube-5.6.3 代码分析平台搭建使用

    python代码分析 官网主页: http://docs.sonarqube.org/display/PLUG/Python+Plugin Windows下安装使用: 快速使用: 1.下载jdk ht ...

  9. angular代码分析之异常日志设计

    angular代码分析之异常日志设计 错误异常是面向对象开发中的记录提示程序执行问题的一种重要机制,在程序执行发生问题的条件下,异常会在中断程序执行,同时会沿着代码的执行路径一步一步的向上抛出异常,最 ...

  10. [Asp.net 5] DependencyInjection项目代码分析4-微软的实现(2)

    在 DependencyInjection项目代码分析4-微软的实现(1)中介绍了“ServiceTable”.“ServiceEntry”.“IGenericService”.“IService”. ...

随机推荐

  1. 12组-Alpha冲刺-总结

    组长博客链接 https://www.cnblogs.com/147258369k/p/15573118.html 一.基本情况 1.1 现场答辩总结 PPT制作方面略显粗糙,对于产品描述的具体内容不 ...

  2. nginx 反向代理 (websocket)后报 - 400 bad request

    nginx的反向代理. nginx.conf中的配置如下: location / {        proxy_http_version                1.1;         pro ...

  3. 技术前沿:ISP芯片终极进化——VP芯片(AI视觉处理器)

    1.计算机视觉的定义 广义与狭义 从广义上说,计算机视觉就是"赋予机器自然视觉能力"的学科.自然视觉能力,就是指生物视觉系统体现的视觉能力. 从狭义上讲,计算机视觉是以图像(视频) ...

  4. usb 2.0 high speed resetting signaling.

  5. linux升级系统内核版导致死锁

    如上图片,官方说明为linux内核版本过低,存在系统bug,具体说明如下: https://baijiahao.baidu.com/s?id=1652492237858209875&wfr=s ...

  6. 向日葵RCE复现

    CNVD-2022-10270/CNVD-2022-03672 向日葵RCE复现 前言 向日葵是一款免费的集远程控制电脑手机.远程桌面连接.远程开机.远程管理.支持内网穿透的一体化远程控制管理工具软件 ...

  7. Win10系统所有文件夹被设为只读,取消之后再次打开属性依然只读,怎么解决?

    安装完Nodejs之后发现npm info vue指令没有权限运行: C:\Users\JC>npm info vue npm ERR! code EPERM npm ERR! syscall ...

  8. 关于给widget添加属性

    在django中,我们通过修改Form/ModelForm的初始化函数__init__修改表单的显示样式,其中修改widget的属性操作和字典操作一致. 1.给widget添加属性 说明:这是在不影响 ...

  9. Fortran处理无符号整型unsigned integer

    背景: 计算机是以一串二进制数,用约定的表示方式来存储数据的.约定表示方式的不同,造成了可以表示数的范围不同.其中,对于整数类型数据的表示,有unsigned integer(无符号整型)和signe ...

  10. powergui模块基本设置

    Powergui模块可以显示系统稳定状态的电流和电压及电路(电感电流和电容电压)所有的状态变量值. 尤其是电力电子仿真中需要加入powergui模块,否则会报错. simulink仿真用到simpow ...