1 Transformer 是一种 seq2seq 建模方法

(著名的 GPT 的全称是 Generative Pre-trained Transformer)

学习 Transformer:

seq2seq 的输入输出:

  • 在 nlp 领域貌似是 word embedding,然后再使用 word2vec 之类得到单词(?)

Attention 与 Transformer:

  • attention:

    • key query value:key 用来提取关键信息、query 用来提取查询、value 用来提取值。它们具有矩阵形式,用 k q v 矩阵去乘输入的 vector,得到 k q v 的 vector。
    • k q v 举例:希望投票选举,query - 评委的重要程度、key - 评委的职称(?)、value - 评委的投票结果,最后按照 (query × key^T) × value 的形式,对投票结果进行加权计算。 multi-head-attention 就是使用多组 k q v,可能表示我们希望关注多个方面。
  • encoder:
    • 一个 encoder 块包含一个 attention + 一个 feed forward 层(大概就是全连接层)。
    • 我们使用 k q v 的 attention 模块,一下对所有 token 的矩阵(维度 num of tokens × embedding size)得到一个 latent z(维度 num of tokens × latent size)。
    • 然后对于每个原句子中的 token,各过各的全连接层(feed forward)(?)最后得到一个 维度 num of tokens × embedding size 的矩阵。
    • 残差连接(Res):将输入和 多头注意力层 或全连接神经网络的输出 相加,再传递给下一层,避免梯度递减的问题。
  • decoder:
    • 一个 decoder 块包含两个 attention + 一个 feed forward 层。
    • attention 1 用来处理自己输出的信息,因此它在说第 n 个单词之前,只能以自己说出的前 n-1 个单词作为输入,使用一个掩码(?)来实现:掩码多头自注意力(Masked-Multi-head self attention)。
    • attention 2 用来处理 encoder 给出的 num of tokens × embedding size 的 embedding,attention 1 的输出也是其输入的一部分。
  • 这样,看图应该就能看懂了。

2 建模 RL 的 sequence

我们的 sequence:{return-to-go, state, action, return-to-go, ...}

  • 形式类似于 \(\{s_t,a_t,r_t,s_{t+1},\cdots\}\) 。
  • return-to-go: \(\hat R_t=\sum_{t'=t}^Tr_{t'}\) ,是从此刻 t 到 episode 结束的,in-discounted reward 的加和。
  • 感觉 return-to-go 类似于 HER 的预期目标,比较 hindsight。

3 如何训练 DT

对 sequence {s, a, R, s, a, R, ...} 进行处理:

  • 对每个 modality(s a R),都学习了将它们转换为 embedding 的线性层。
  • 对于具有视觉输入的环境,状态被输入到卷积编码器而不是线性层中。
  • 此外,每个时间步的 embedding 都会被学习并添加到每个 token 中 —— 这与 transformer 使用的 positional embedding(三角函数?)不同,因为一个时间步对应于三个 token。

(搬图,搬运文字说明)

  • 一条轨迹按照 s a R 顺序排列好后,每个元素都是图 1 下部的一个小圆圈,类似于 NLP 中的一个个单词。
  • 然后,每个元素经过一个 mlp 做 embedding 后,再加上 position encoding,就得到了 tokens,也就是图 1 下部的一个个五颜六色的小长方块。

训练:

  • 使用 offline trajectory 的 dataset(D4RL 之类)。
  • 从离线轨迹数据集中,抽取 sequence 长度为 K 的 minibatch。
  • 训练:对 input token \(s_t\) 的那个 prediction head,再加一个 mlp 来预测 \(a_t\) (上图上部输出 \(a_t\) 的橙色方块)。
    • 训 action 时,对离散动作使用 cross-entropy loss,连续则使用 MSE。
    • DT 每隔三个 token 才 decode 一个,因为作为 policy 只需要输出 action。但其实,output tokens 由对应 return-to-go、state、action 的 token 组成,所以自然只留下对应 action 的 tokens(?)
    • 发现,去预测 state 或 return-to-go 并不能提高性能,尽管在 DT 框架里,很容易这么做。
  • 上述训练部分是想让 DT 学会,在某个特定状态 s 下,达到 return-to-go R,所需要做的动作 a。
  • 详见 Algorithm 1 伪代码,感觉写的很清楚。

4 如何部署 DT policy / 如何 inference

  • inference 过程就是,首先提出一个 target return(我们希望 agent 在一个 episode 里能达到的 return),作为初始的 return-to-go,然后 DT 按照训练过程中学到的 如何达成 return-to-go 的方法,选择 action。
  • 每走一步,就将上一步的 return-to-go 减去这一步的 reward,得到下一个 return-to-go,从而不断地更新我们期望 DT 达到的 return 目标,同时 DT 根据我们的目标,不断选择 action。
  • 详见 Algorithm 1 伪代码。
  • evaluate 时,只保留 length = K 的 context,对应于前面训练时 sequence length = K。
    • 通常认为,当使用 frame stacking(Atari 的帧堆叠)时,K = 1 已经 MDP,足以用于 RL 算法。然而,当 K = 1 时,Decision Transformer 的性能明显更差,这表明过去的信息对 Atari 游戏有用(非 MDP?)。(具体实验中,Atari 的 K = 30 50,MuJoco 的 K = 5 20)
    • 一个假设是,当我们表示 一些策略的分布时(例如序列建模),上下文允许 transformer 识别,哪个策略生成了该动作,从而实现更好的学习和 / 或改进训练动态。

5 技术细节

  • 训练的一些超参数,encoder / decoder:可参见 Table 8 9。
  • 在 inference 过程中,如何选择 return-to-go:使用 dataset 中最大 return 的一倍或 5 倍。
  • Warmup tokens 为 512 ∗ 20,是 <begin> <end> 这种 token 嘛?

6 一些讨论

  • (Section 5.7)DT 为什么不需要 pessimistic value 或行为正则化?作者猜想:pessimistic value 和行为正则化是为了避免 value function approximation 带来的问题,但 DT 并不需要显式优化一个函数(?)
  • (Section 5.8)声称 DT + Go-Explore(RL exploration 方法,感觉像打表)可以帮助 online policy。
  • Credit assignment 貌似是一类工作,通过分解 reward function,使得某些“重要”状态包含了大部分 credit。

offline RL | 读读 Decision Transformer的更多相关文章

  1. 论文笔记之:Learning to Track: Online Multi-Object Tracking by Decision Making

    Learning to Track: Online Multi-Object Tracking by Decision Making ICCV   2015 本文主要是研究多目标跟踪,而 online ...

  2. 深度强化学习(Deep Reinforcement Learning)入门:RL base & DQN-DDPG-A3C introduction

    转自https://zhuanlan.zhihu.com/p/25239682 过去的一段时间在深度强化学习领域投入了不少精力,工作中也在应用DRL解决业务问题.子曰:温故而知新,在进一步深入研究和应 ...

  3. OAF_文件系列6_实现OAF导出XML文件javax.xml.parsers/transformer(案例)

    20150803 Created By BaoXinjian

  4. (zhuan) 一些RL的文献(及笔记)

    一些RL的文献(及笔记) copy from: https://zhuanlan.zhihu.com/p/25770890  Introductions Introduction to reinfor ...

  5. 【HEVC帧间预测论文】P1.5 Fast Coding Unit Size Selection for HEVC based on Bayesian Decision Rule

    Fast Coding Unit Size Selection for HEVC based on Bayesian Decision Rule <HEVC标准介绍.HEVC帧间预测论文笔记&g ...

  6. 深度学习-强化学习(RL)概述笔记

    强化学习(Reinforcement Learning)简介 强化学习是机器学习中的一个领域,强调如何基于环境而行动,以取得最大化的预期利益.其灵感来源于心理学中的行为主义理论,即有机体如何在环境给予 ...

  7. (转)RL — Policy Gradient Explained

    RL — Policy Gradient Explained 2019-05-02 21:12:57 This blog is copied from: https://medium.com/@jon ...

  8. 【强化学习RL】必须知道的基础概念和MDP

    本系列强化学习内容来源自对David Silver课程的学习 课程链接http://www0.cs.ucl.ac.uk/staff/D.Silver/web/Teaching.html 之前接触过RL ...

  9. [1] Multi-View Transformer for 3D Visual Grounding 论文精读

    参考: https://zhuanlan.zhihu.com/p/467913475 3D Visual Grounding小白调研笔记 https://zhuanlan.zhihu.com/p/34 ...

  10. 多精度 simulator 中的 RL:一篇 14 年 ICRA 的古早论文

    目录 全文快读 0 abstract 1 intro 2 related work 3 背景 & 假设 3.1 RL & KWIK(know what it knows)的背景 3.2 ...

随机推荐

  1. [转帖]将 Cloudflare 连接到互联网的代理——Pingora 的构建方式

    https://zhuanlan.zhihu.com/p/575228941 简介 今天,我们很高兴有机会在此介绍 Pingora,这是我们使用 Rust 在内部构建的新 HTTP 代理,它每天处理超 ...

  2. ThreadLocal源码解析及实战应用

    作者:京东物流 闫鹏勃 1 什么是ThreadLocal? ThreadLocal是一个关于创建线程局部变量的类. 通常情况下,我们创建的变量是可以被任何一个线程访问并修改的.而使用ThreadLoc ...

  3. [置顶] ES篇

    环境 第一篇: ES--介绍.安装 第二篇: ES--安装nodejs 第三篇: ES--插件 第四篇: ES--kibana介绍.安装 第五篇: ES--中文分词介绍.下载 第六篇: docker, ...

  4. Fabric升级示例

    Fabric v1.4.x升级至v2.2.0 本文首发于这里,转载请注明出处. 以fabric-samples v1.4.8为例,将v1.4.8升级至v2.2.0.注意,所有节点以滚动的方式进行升级, ...

  5. [LeetCode刷题记录]113 路径总和 II

    题目描述 给定一个二叉树和一个目标和,找到所有从根节点到叶子节点路径总和等于给定目标和的路径. 说明: 叶子节点是指没有子节点的节点. 难度 中等 题解 采用深度搜索优先,遍历每条从根节点到叶子节点的 ...

  6. 结构体定义及结构体粒度(alignment)

    结构体定义及结构体粒度(alignment) #pragma pack(1) typedef struct _STUDENT_INFORMATION_ { int Age; char v1; int ...

  7. 基于protobuf和httplib的在线通讯录项目框架|Protobuf应用小项目

    前言 那么这里博主先安利一些干货满满的专栏了! 首先是博主的高质量博客的汇总,这个专栏里面的博客,都是博主最最用心写的一部分,干货满满,希望对大家有帮助. 高质量博客汇总https://blog.cs ...

  8. 教你用JavaScript实现表情评级

    案例介绍 欢迎来到我的小院,我是霍大侠,恭喜你今天又要进步一点点了!我们来用JavaScript编程实战案例,做一个表情评价程序.用户打星进行评价,表情会根据具体星星数量发生变化. 案例演示 点击星星 ...

  9. (python)每日代码||2024.1.18||元组中的列表成员可以改变内容,不可以改变该列表成员

    t = ([1,2,3],[2,3,4],3) print(t) t[0][1]=9 print(t) # ~ t[2]=9#TypeError: 'tuple' object does not su ...

  10. 素数打表,洛谷P1217 [USACO1.5]回文质数 Prime Palindromes

    这道题的最后一个样例TLE(超时)了,判断素数的条件是 i*i<n 1 #include<stdio.h> 2 #include<stdlib.h> 3 #include ...