Citation

Al-Molegi A , Martínez-Ballesté, Antoni, Jabreel M . Move, Attend and Predict: An Attention-based Neural Model for People’s Movement Prediction[J]. Pattern Recognition Letters, 2018:S016786551830182X.

概览

本文与之前所阅读的几篇轨迹预测文章不同,从实质上说,前面的轨迹预测是回归问题,而本文则是一个分类问题,其采纳循环神经网络对小场景中轨迹预测的提升,将其运用于更大时间跨度(最小为小时,由GPS、打卡机等设备采集)的地点变换预测上。具体来说,定义Move, Attend and Predict (MAP)模型,模型的输入由(二维地址, 时间戳)构成,输出则为根据以往地址信息所预测的下一个地址,模型由RNN编码器、注意力模型和预测模型三部分组成,总体来说结构比较简单,但其在实验评估部分的方法留留给我了一些启示,稍后将在文章中给出。

HighLights

  1. 时间信息与注意力机制:以往相关研究如STF-RNN网络将(地点独热值,时间点独热值)元组一并嵌入作为循环神经网络的输入。而MAP模型则采用另一种思路,引入注意力模型,使用RNN单独处理二维地址信息并保存输出,时间戳信息则以计算注意力权重并生成注意力向量的身份参与到模型中。
  2. 神经网络的可解释性研究:文章在实验部分对数据时间戳的定义、注意力机制的有效性、嵌入维度进行了细致地探讨,通过可视化方式较为直观地得出了:
    • 注意力机制令模型更关注最近时间的信息
    • 时间戳以小时效果最佳且应设为离开地点的时刻
    • 嵌入维度对模型提升瓶颈与24小时制有关
  3. 离散化衡量指标:设备限制和数据处理方法使得模型的地点信息是离散且有限的,因此模型的评估同之前行人轨迹预测中ADE和FDE连续化指标不同,分别为准确率,召回率和F1-Score。

Future Work

  1. 尝试使用更高级的RNN单元如GRN、LSTM。
  2. 纳入更多信息考虑因素,如行人交互和地点之间的距离。
  3. 克服模型无法预测未知地点的问题,引入未知地点的概率预测。

模型

MAP模型由三部分组成:地点信息模型(左部灰色区)、注意力模型(右部灰色区)、分类器(上)

规范

对于行人,给定\(w\)对\(p_i=(l_i,t_i)\ \ (1 \le i \le w)\)元组-分别表示地点和时间戳,表示该行人过去的轨迹序列。

  • 地点处在有限集合中,因此已经编码为独热编码(维度N)。
  • 时间戳以小时为单位划分,其代表的均是离开与其对应地点的时刻,也采用的是独热编码(维度M)。

模型的目的是基于这\(w\)对信息,推测行人下一步的地点:\(P(l_{i+1}|p_i,...,p_1)\)

地点信息模型

地点信息模型是基本的RNN结构,其首先将\(N\)维的独热值地点信息经过嵌入矩阵\(Le\)生成\(d_l\)维向量,而后作为RNN的每一步的输入参与编码,最后一次的RNN输出(维度为\(d_r\))作为summary vector参与注意力模型运算和分类器:

\[le_i = l_i \cdot Le, \ \ r_i = RNN(le_i;W_{rnn})\]

\[r_i \in \R^{d_r}\]

[注意]:模型下标编号是倒序的,以\(i\)为结尾,一直到\(i-w+1\),因此\(r_i\)是RNN最后一个输出。

注意力模型

Question

请仔细参考结构图明确MAP的注意力模型中,“注意的对象到底是谁?“。

应该为时间戳经过嵌入后形成的W个嵌入向量,而不是RNN模型输出,这点需要和带注意力机制的RNN模型区分开。

用Attention Mechanism机制中的三个指标(Query,Key和Value)来具体刻画此模型的注意力机制:

  • Query:来自RNN网络的Summary Vector \(r_i\)
  • Key = Value=\(\Omega\):时间戳独热值经嵌入处理后的\(w\)个向量, \(te_i = t_i \cdot Te, \ \ \{te_{i-w},..,i\} = \Omega\)

注意力权重由Query和Key点乘并归一化得到:

\[\alpha =softmax(\Omega \cdot r_i) \ \ \alpha \in \R^{w \times 1}\]

注意力向量由注意力权重和Value进行element-wise的乘法运算:

\[\eta = \alpha * \Omega \ \ \eta \in \R^{w \times d_r}\]

分类器

分类器实质就是综合RNN编码器和注意力模型两部分模型的信息,进行简单的线性变换,并用softmax压缩为概率对于每个地点的预测概率。这里的综合函数文章给出了两种参考,一种是拼接,另一种是相加,\(W_F\)根据不同策略维度需要有所变化。

\[\hat y = softmax(F(r_i, \eta) \cdot W_F+b)\]

Optimize

模型最后是优化方法和损失函数,优化方法采用的是ADADELTA,损失函数则直接基于softmax输出的连续概率分布计算(否则离散化后无法求导进行反向传播),这与评估时需要离散的方式是不同的。

\[Loss = - \Sigma^n_{i=1}y_i \cdot log(\hat y_i)\]

[注意]:上述公式只是一个人的损失,i遍历的是有限n个地点,\(y_i\)只有0和1的取值,\(\hat y_i\)是\(\hat y\)中的具体数值。

模型评估

数据集

MAP模型的数据集相比之前轨迹预测的数据集更宏观,时间跨度以小时为最小单位,并且地点也是离散和有限的。

  • Geolife:GPS设备记录的原始信息,作者在将Log信息转换为轨迹信息时,首先使用算法侦测一些地区,而后用DBSCAN聚类算法(\(\varepsilon = 100, \ minPts=3\))形成了离散固定的地点,行人的位置信息只能由这些固定地点所表示。
  • Gowalla:处在固定地点的打卡机所记录的时间戳数据。

量化评价

  1. 评价指标:由于位置离散有限,因此预测是采用softmax计算概率分布并采样,这使得MAP模型与之前预测连续分布的模型评价指标(ADE,FDE)不同。

    • 取前N个最有可能地点最为预测地点集合\(L_{N,u}\), 真实的地点集合\(L_u\)。
    • 准确率 - 预测地点集合中有多少真实命中的:\(Precision@N={1 \over |U|}\Sigma _{u \in U}{|L_u \bigcap P_{N,u}| \over |P_{N,u}|}\)。
    • 召回率 - 真实地点集合中有多少被预测到的:\(Recall@N = {1 \over |U|}{\Sigma_{u \in U}}{|L_u \bigcap P_{N,u}| \over |L_u|}\)
    • \(F1-Score@N = 2 \times {Precision@N \times Recall@N \over Precision@N + Recall@N}\)
  2. 简要结论(具体请参见原文)
    • 基础RNN模型提升的预测能力,若考虑加入时间因素,预测能力进一步提升。
    • 数据嵌入(embedding)对模型预测能力提升明显,这是因为嵌入层能模型很好提取潜在的语义信息
    • MAP表现最好……

神经网络可解释化的研究

文章在神经网络的可解释方面做了很深的研究,进一步加强了数据定义与模型设计的合理性。

注意力机制

首先,文章中"被注意"的对象是时间独热值经嵌入层得到的\(w\)个时间嵌入向量,文章探讨在计算注意力权重\(\alpha\)时考虑因素的不同对注意力产生的影响。设置已知轨迹长度\(w=2\),两个对比分别是\(\alpha_1 = softamx(g(r_i, \Omega)) \ \ VS \ \ \alpha_2=softmax(g(r_i))\),最后探讨\(\alpha\)的权重分布特点,下图中case1是考虑空间+时间,case2只考虑空间。

结论:将时间因素$\Omega \(纳入考虑得到的\)\alpha$符合人为的认知结果 - 更关注距离当前时间点更近的时间嵌入向量。

隐藏层(时间嵌入)维度最优值

模型中为保证维度正确性,隐藏层维度和时间嵌入向量维度保持一致,根据实验结果,在\(d_r=24\)左右达到峰值,或高或低都导致预测能力下降。这恰好说明模型对24小时制的学习效果,过高或过低维度形成的时间段都将导致与时间戳定义(小时制)的不吻合。

时间戳定义

时间戳的定义由两方面问题,一是单位的定义(小时,时辰,天,月……),二是选择哪个时刻与对应地点相对应,经过对比,最终得出:

  • 小时制度效果更好。
  • 选择离开该地点的时刻作为时间戳效果更好,印证了“离开时刻对预测下一步地点最具影响力”的人为认知。

文献阅读报告 - Move, Attend and Predict的更多相关文章

  1. 文献阅读报告 - Social BiGAT + Cycle GAN

    原文文献 Social BiGAT : Kosaraju V, Sadeghian A, Martín-Martín R, et al. Social-BiGAT: Multimodal Trajec ...

  2. 文献阅读报告 - Social Ways: Learning Multi-Modal Distributions of Pedestrian Trajectories with GANs

    文献引用 Amirian J, Hayet J B, Pettre J. Social Ways: Learning Multi-Modal Distributions of Pedestrian T ...

  3. 文献阅读报告 - Situation-Aware Pedestrian Trajectory Prediction with Spatio-Temporal Attention Model

    目录 概览 描述:模型基于LSTM神经网络提出新型的Spatio-Temporal Graph(时空图),旨在实现在拥挤的环境下,通过将行人-行人,行人-静态物品两类交互纳入考虑,对行人的轨迹做出预测 ...

  4. 文献阅读报告 - 3DOF Pedestrian Trajectory Prediction

    文献 Sun L , Yan Z , Mellado S M , et al. 3DOF Pedestrian Trajectory Prediction Learned from Long-Term ...

  5. 文献阅读报告 - Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks

    paper:Gupta A , Johnson J , Fei-Fei L , et al. Social GAN: Socially Acceptable Trajectories with Gen ...

  6. 文献阅读报告 - Social LSTM:Human Trajectory Prediction in Crowded Spaces

    概览 简述 文献所提出的模型旨在解决交通中行人的轨迹预测(pedestrian trajectory prediction)问题,特别是在拥挤环境中--人与人交互(interaction)行为常有发生 ...

  7. 文献阅读报告 - Pedestrian Trajectory Prediction With Learning-based Approaches A Comparative Study

    概述 本文献是一篇文献综述,以自动驾驶载具对外围物体行动轨迹的预测为切入点,介绍了基于运动学(kinematics-based)和基于机器学习(learning-based)的两大类预测方法. 并选择 ...

  8. 文献阅读报告 - Context-Based Cyclist Path Prediction using RNN

    原文引用 Pool, Ewoud & Kooij, Julian & Gavrila, Dariu. (2019). Context-based cyclist path predic ...

  9. 文献阅读笔记——group sparsity and geometry constrained dictionary

    周五实验室有同学报告了ICCV2013的一篇论文group sparsity and geometry constrained dictionary learning for action recog ...

随机推荐

  1. JAVA 发送各种邮箱邮件 javamail

    QQ邮箱 /** * 单条发送 * @param mail 邮件对象,包含发送人.邮件主题.邮件内容 * @param recipients 收件人 * @throws AddressExceptio ...

  2. KAZE特征和各向异性扩散滤波

    kaze feature: http://www.doc88.com/p-6911376909693.html 各向异性扩散滤波  Scale-space and edge detection usi ...

  3. 一步步教你整合SSM框架(Spring MVC+Spring+MyBatis)详细教程重要

    前言 SSM(Spring+SpringMVC+Mybatis)是目前较为主流的企业级架构方案,不知道大家有没有留意,在我们看招聘信息的时候,经常会看到这一点,需要具备SSH框架的技能:而且在大部分教 ...

  4. HTML中用自定义字体实现小图标icon(不是原作, 只是一个研究笔记)

    最近在做一个项目时, 研究了一下新浪微博的前端, 看到首页中那个图标了吗, 以前看到这类效果的第一反应就是用一个gif之类的图标做出来!! 但在研究的过程, 发现了一个小技巧, 注意那个em标签中的文 ...

  5. Linux服务器命令大全

    快捷提示键: table 查看文件夹:  ls , ls –all ,ls –l,ll 进入某个文件夹: cd usr/local 回到root 目录 : cd /root/ 回到根目录:cd / 回 ...

  6. unity基础开发----Unity获取PC,Ios系统的mac地址等信息

    在软件开发中可以会用到mac地址作为,设备的唯一标示,我们也可以通过unity获取,经测试pc,ios都可以但是安卓没有获取到. 代码如下: using UnityEngine; using Syst ...

  7. docker 构建php-fpm IMages(dockerfile)

    好久没写blog   做什么?   复习nginx  zabbix  docker-compos mariadb  学习 jenkins ansible ELK  k8s (kubeadm)  好了也 ...

  8. UVA - 11354 Bond(最小生成树+LCA+瓶颈路)

    题意:N个点,M条路,每条路的危险度为路上各段中最大的危险度.多组询问,点s到点t的所有路径中最小的危险度. 分析: 1.首先建个最小生成树,则s到t的路径一定是危险度最小的. 原因:建最小生成树的最 ...

  9. 从ofo牵手理财平台看,用户隐私数据的使用有底线吗?

    智慧生活的到来既是社会变迁的拐点,又不可避免地带来一种挥之不去的焦虑.这种焦虑的由来,是因个人隐私数据在智慧生活下变成一种"黑暗财富".随着相关数据挖掘.收集.分析技术的成熟,人们 ...

  10. excel提取数字

    部分提取,那么就用=-LOOKUP(,-MID(A1,MIN(FIND({0;1;2;3;4;5;6;7;8;9},A1&1234567890)),ROW($1:$1024))) ------ ...