近日在阅读Social GAN文献的实验代码,加深对模型的理解,发现源代码的工程化很强,也比较适合构建实验模型的学习,故细致阅读。下文是笔者阅读中一些要点总结,有关于pytorch,也有关于模型自身的。

GPU -> CPU

SGAN的实验代码在工程化方面考虑比较充分,考虑到了在CPU和GPU两种平台上模型的运行。原生平台是GPU,若要切换为CPU,需要做如下改动(目前只改动了训练过程所需的,测试评估还未进行,但估计类似):

  1. args.use_gpu需要置为0,以保证int_dtypefloat_dtype不是cuda。
  2. 检索cuda(),可以发现在model.py还有些残缺未考虑的cuda定义,使用torch.cuda.is_available()判断是否GPU可使用,只有可行采用cuda()定义:
x = xxx()
if torch.cuda.is_available():
x = x.cuda()

池化层实现细节

Social GAN相较于Social LSTM提出了新的池化模型以满足不同行人轨迹间信息共享与相互作用,具体有以下几个方面的变动:

  1. Social GAN的池化频率为一次,只在利用已知轨迹编码后进行一次池化。(代码中一个额外选项是在预测的每一步都进行池化)
  2. 池化范围为全局而不是固定的范围区间,代码使用max pooling的手段使得在场景人数不确定的情况下可以保持数据维度固定。
  3. 池化输入数据由两方面组成:LSTMs的隐藏状态+最后位置的相对信息

而在代码实现时,计算相对位置信息时显得比较巧妙,例如在同场景的行人位置信息,代码通过两次不同的repeat策略将原有N个人的位置信息重复N次,从而形成了[P0, P0, P0, ...] [P1, P1, P1, ...] ... 和 [P0, P1, P2, ...] [P0, P1, P2, ...] ..两个矩阵,通过矩阵相减即可得到一个N*N行的矩阵,第\(i\)行是第\(i \% N\)个人相对于第\(i / N\)个人的相对位置。

	curr_hidden = h_states.view(-1, self.h_dim)[start:end]
curr_end_pos = end_pos[start:end] # Repeat -> [H1, H2, H3, ...][H1, H2, H3, ...]...
curr_hidden_1 = curr_hidden.repeat(num_ped, 1)
# Repeat position -> [P1, P2, P3, ...][P1, P2, P3, ...]...
curr_end_pos_1 = curr_end_pos.repeat(num_ped, 1)
# Repeat position -> [P1, P1, P1, ...][P2, P2, P2, ...]...
curr_end_pos_2 = self.repeat(curr_end_pos, num_ped)
# 得到行人的end_pos间的相对关系,并交给感知机去具体处理。
# 每个行人与其他行人的相对位置关系由num_ped项,合计有num_ped**2项。
curr_rel_pos = curr_end_pos_1 - curr_end_pos_2
curr_rel_embedding = self.spatial_embedding(curr_rel_pos) # 拼接H_i和处理过的pos,放入多层感知机,最后经过maxPooling。
mlp_h_input = torch.cat([curr_rel_embedding, curr_hidden_1], dim=1)
curr_pool_h = self.mlp_pre_pool(mlp_h_input)

DataLoader相关

安利一个知乎,上面对使用Pytorch实现dataLoader解释得很细致

https://zhuanlan.zhihu.com/p/30385675

dataLoader迭代器的数据格式

Dataset继承而来的TrajectoryDataSet__get_item__进行了重写,以方便dataLoader使用并整合,每次函数返回的是一个列表:

out = [
self.obs_traj[start:end, :], self.pred_traj[start:end, :],
self.obs_traj_rel[start:end, :], self.pred_traj_rel[start:end, :],
self.non_linear_ped[start:end], self.loss_mask[start:end, :]
]
return out

列表中有6个元素,以obs_traj为例,其大小为[N][2][seq_len],但是在使用dataLoader进行迭代时出现了这种形式,不仅一个batch中解压得到的变为7个,而且obs_traj的大小变为[seq_len][batch][2],,顺序发生了变化.

batch = [tensor.cuda() for tensor in batch]
(obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,loss_mask, seq_start_end) = batch

Solution

  1. 问题主要是忽视了DataLoadercollate_fn函数的作用,这个函数是在trajectory.py中自定义的函数,主要作用时当dataLoader收集到batch_sizeitem后形成一个列表,而后交由自定义的collate_fn做预处理,处理后的数据就会被输出为batch
  2. seq_collate解答了数据格式的两个疑问,包括使用permutecat函数。

从dataLoader获取的batch数据的概念辨析

Solution

  1. batch != batch_size

    1. 模型注释中有多处使用batch来表示张量格式,一个batch的数据常常有batch_size行,但在该模型中不成立。
    2. 严格来说,一个batch中有batch_sizeitem,但一个item可以用多行表示,这就是该模型的数据特点,其在一个batch中额外新增了seq_start_end列表(len(seq_start_end) == batch_size),使用该列表即可抽取出一个item

    \[batch = \Sigma_{i=0}^{batch\_size-1}N_i (N_i \ge min\_peds)
    \]

    \(N_i\)表示一个场景下的行人个数。

  2. 一个batch中有多场景的行人轨迹数据

    1. LSTM编码和译码:每个轨迹都是独立的,此时可以整个batch一起处理
    2. 池化:设计同一场景下各行人序列数据交互,需要使用seq_start_end划分场景分别计算。

Social GAN代码要点记录的更多相关文章

  1. iBatis & myBatis & Hibernate 要点记录

    iBatis & myBatis & Hibernate 要点记录 这三个是当前常用三大持久层框架,对其各自要点简要记录,并对其异同点进行简单比较. 1. iBatis iBatis主 ...

  2. JAVA 中LinkedHashMap要点记录

    JAVA 中LinkedHashMap要点记录 构造函数中可能出现的几个参数说明如下: 1.initialCapacity 初始容量大小,使用无参构造方法时,此值默认是16 2.loadFactor ...

  3. 文献阅读报告 - Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks

    paper:Gupta A , Johnson J , Fei-Fei L , et al. Social GAN: Socially Acceptable Trajectories with Gen ...

  4. Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks

    Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks 2019-06-01 09:52:4 ...

  5. 编写高质量JavaScript代码的基本要点记录

    原文:深入理解JavaScript系列(1):编写高质量JavaScript代码的基本要点 1.最小全局变量(Minimizing Globals)的重要性 JavaScript通过函数管理作用域.在 ...

  6. python学习第一课要点记录

    写在要点之前的一段话,留给将来的自己:第一次参加编程的培训班,很兴奋很激动,之前都是自己在网上找免费的视频来看,然后跟着写一些课程中的代码,都是照着模子写,没有自己过多的思考.感觉这样学不好,除了多写 ...

  7. Android开发入门要点记录:四大组件

    cocos2dx跨平台开发中需要了解android开发,昨天快速的浏览了一本Android开发入门教程,因为之前也似懂非懂的写过Activity,Intent,XML文件,还有里面许多控件甚至编程思想 ...

  8. Unity Scripting Tutorials 要点记录

    (搬运自我在SegmentFault的博客) 这几天通过Unity官网的Unity Scripting Tutorials的视频学习Unity脚本,观看的过程中做了记录.现在,整理了一下笔记,供自己以 ...

  9. web基础要点记录

    最近公司项目做完了,不怎么忙,翻看了一些基础的资料,文章.就做了个简单的记录. 1.Chrome 中文界面下默认会将小于 12px 的文本强制按照 12px 显示, 可通过加入 CSS 属性  -we ...

随机推荐

  1. Problem A: Assembly Required K路归并

    Problem A: Assembly Required Princess Lucy broke her old reading lamp, and needs a new one. The cast ...

  2. zabbix监控memcached服务

    zabbix监控memcached服务 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.安装并配置memcached服务 1>.使用yum方式安装memcached [ro ...

  3. QT5安装

    Windows+Qt5.3.1+VS2013安装教程 https://blog.csdn.net/two_ye/article/details/96109876 (已成功)windows下,VS201 ...

  4. C#遍历DataSet

    ] foreach (DataRow dr in dt.Rows) ///遍历所有的行 foreach (DataColumn dc in dt.Columns) //遍历所有的列 Console.W ...

  5. stm32串口收发导致的死机

    stm32串口收发导致的死机 很久以前有偶尔遇到过串口死机的情况,那是当时的我写出来的代码自己都觉得有问题,也就没注意.用了stm32做项目以后也就没遇到过了,今天做了个高压测试,每5ms定时发送一次 ...

  6. CentOS 7安装/卸载Redis,配置service服务管理

    Redis简介 Redis功能简介 Redis 是一个开源(BSD许可)的,内存中的数据结构存储系统,它可以用作数据库.缓存和消息中间件. 相比于传统的关系型数据库,Redis的存储方式是key-va ...

  7. C++面试常见问题——14内存管理

    内存管理 内存管理由三种方式: 自动存储 静态存储 动态存储 自动存储 对于函数的形参.函数内部变量.和结构体变量等,编译器在函数运行过程中在栈中自动对其分配内存,调用结束后对其进行销毁.变量的声明周 ...

  8. 011、Java中将范围大的数据类型变为范围小的数据类型

    01.代码如下 package TIANPAN; /** * 此处为文档注释 * * @author 田攀 微信382477247 */ public class TestDemo { public ...

  9. 001.Oracle数据库 , 查询日期在两者之间

    /*Oracle数据库查询日期在两者之间*/ SELECT OCCUR_DATE FROM LM_FAULT WHERE ( ( OCCUR_DATE >= to_date( '2017-05- ...

  10. 015.Delphi插件之QPlugins,FMX插件窗口

    内嵌FMX的插件窗口,效果还是很可以的.退出时,会报错,很诡异啊. 主窗口代码如下 unit Frm_Main; interface uses Winapi.Windows, Winapi.Messa ...