最近在复现 PPO 跑 MiniGrid,记录一下…

这里跑的环境是 Empty-5x5 和 8x8,都是简单环境,主要验证 PPO 实现是否正确。

01 Proximal policy Optimization(PPO)

(参考:知乎 | Proximal Policy Optimization (PPO) 算法理解:从策略梯度开始

首先,策略梯度方法 的梯度形式是

\[\nabla_\theta J(\theta)\approx
\frac1n \sum_{i=0}^{n-1} R(\tau_i)
\sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t|s_t)
\tag1
\]

然而,传统策略梯度方法容易一步走的太多,以至于越过了中间比较好的点(在参考知乎博客里称为 overshooting)。一个直观的想法是限制策略每次不要更新太多,比如去约束 新策略 旧策略之间的 KL 散度(公式是 plog(p/q)):

\[D_{KL}(\pi_\theta | \pi_{\theta+\Delta \theta}) = \mathbb E_{s,a}
\pi_\theta(a|s)\log\frac{\pi_\theta(a|s)}{\pi_{\theta+\Delta \theta}(a|s)} \le \epsilon
\tag2
\]

我们把这个约束进行拉格朗日松弛,将它变成一个惩罚项:

\[\Delta\theta^* = \arg\max_{\Delta\theta} J(\theta+\Delta\theta) -
\lambda [D_{KL}(\pi_\theta | \pi_{\theta+\Delta \theta})-\epsilon]
\tag3
\]

然后再使用一些数学近似技巧,可以得到自然策略梯度(NPG)算法。

NPG 算法貌似还有种种问题,比如 KL 散度的约束太紧,导致每次更新后的策略性能没有提升。我们希望每次策略更新后都带来性能提升,因此计算 新策略 旧策略之间 预期回报的差异。这里采用计算 advantage 的方式:

\[J(\pi_{\theta+\Delta\theta})=J(\pi_{\theta})+\mathbb E_{\tau\sim\pi_{\theta+\Delta\theta}}\sum_{t=0}^\infty
\gamma^tA^{\pi_{\theta}}(s_t,a_t)
\tag{4}
\]

其中优势函数(advantage)的定义是:

\[A^{\pi_{\theta}}(s_t,a_t)=\mathbb E(Q^{\pi_{\theta}}(s_t,a_t)-V^{\pi_{\theta}}(s_t))
\tag{5}
\]

在公式 (4) 中,我们计算的 advantage 是在 新策略 的期望下的。但是,在新策略下蒙特卡洛采样(rollout)来算 advantage 期望太麻烦了,因此我们在原策略下 rollout,并进行 importance sampling,假装计算的是新策略下的 advantage。这个 advantage 被称为替代优势(surrogate advantage):

\[\mathcal{L}_{\pi_{\theta}}\left(\pi_{\theta+\Delta\theta}\right) =
J\left(\pi_{\theta+\Delta\theta}\right)-J\left(\pi_{\theta}\right)\approx E_{s\sim\rho_{\pi\theta}}\frac{\pi_{\theta+\Delta\theta}(a\mid s)}{\pi_{\theta}(a\mid s)} A^{\pi_{\theta}}(s, a)
\tag6
\]

所产生的近似误差,貌似可以用两种策略之间最坏情况的 KL 散度表示:

\[J(\pi_{\theta+\Delta\theta})-J(\pi_{\theta})\geq\mathcal{L}_{\pi\theta}(\pi_{\theta+\Delta\theta})-CD_{KL}^{\max}(\pi_{\theta}||\pi_{\theta+\Delta\theta})
\tag7
\]

其中 C 是一个常数。这貌似就是 TRPO 的单调改进定理,即,如果我们改进下限 RHS,我们也会将目标 LHS 改进至少相同的量。

基于 TRPO 算法,我们可以得到 PPO 算法。PPO Penalty 跟 TRPO 比较相近:

\[\Delta\theta^{*}=\underset{\Delta\theta}{\text{argmax}}
\Big[\mathcal{L}_{\theta+\Delta\theta}(\theta+\Delta\theta)-\beta\cdot \mathcal{D}_{KL}(\pi_{\theta}\parallel\pi_{\theta+\Delta\theta})\Big]
\tag 8
\]

其中,KL 散度惩罚的 β 是启发式确定的:PPO 会设置一个目标散度 \(\delta\),如果最终更新的散度超过目标散度的 1.5 倍,则下一次迭代我们将加倍 β 来加重惩罚。相反,如果更新太小,我们将 β 减半,从而扩大信任域。

接下来是 PPO Clip,这貌似是目前最常用的 PPO。PPO Penalty 用 β 来惩罚策略变化,而 PPO Clip 与此不同,直接限制策略可以改变的范围。我们重新定义 surrogate advantage:

\[\begin{aligned}
\mathcal{L}_{\pi_{\theta}}^{CLIP}(\pi_{\theta_{k}}) = \mathbb E_{\tau\sim\pi_{\theta}}\bigg[\sum_{t=0}^{T}
\min\Big( & \rho_{t}(\pi_{\theta}, \pi_{\theta_{k}})A_{t}^{\pi_{\theta_{k}}},
\\
& \text{clip} (\rho_{t}(\pi_{\theta},\pi_{\theta_{k}}), 1-\epsilon, 1+\epsilon) A_{t}^{\pi_{\theta_{k}}}
\Big)\bigg]
\end{aligned}
\tag 9
\]

其中, \(\rho_{t}\) 为重要性采样的 ratio:

\[\rho_{t}(\theta)=\frac{\pi_{\theta}(a_{t}\mid s_{t})}{\pi_{\theta_{k}}(a_{t}\mid s_{t})}
\tag{10}
\]

公式 (9) 中,min 括号里的第一项是 ratio 和 advantage 相乘,代表新策略下的 advantage;min 括号里的第二项是对 ration 进行的 clip 与 advantage 的相乘。这个 min 貌似可以限制策略变化不要太大。

02 如何复现 PPO(参考 stable baselines3 和 clean RL)

代码主要结构如下,以 stable baselines3 为例:(仅保留主要结构,相当于伪代码,不保证正确性)

import torch
import torch.nn.functional as F
import numpy as np # 1. collect rollout
self.policy.eval()
rollout_buffer.reset()
while not done:
actions, values, log_probs = self.policy(self._last_obs)
new_obs, rewards, dones, infos = env.step(clipped_actions)
rollout_buffer.add(
self._last_obs, actions, rewards,
self._last_episode_starts, values, log_probs,
)
self._last_obs = new_obs
self._last_episode_starts = dones with torch.no_grad():
# Compute value for the last timestep
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) # 2. policy optimization
for rollout_data in self.rollout_buffer.get(self.batch_size):
actions = rollout_data.actions
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
advantages = rollout_data.advantages
# Normalize advantage
if self.normalize_advantage and len(advantages) > 1:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration
ratio = torch.exp(log_prob - rollout_data.old_log_prob) # clipped surrogate loss
policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean() # Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(rollout_data.returns, values_pred) # Entropy loss favor exploration
entropy_loss = -torch.mean(entropy) loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss # Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()

大致流程:收集当前策略的 rollout → 计算 advantage → 策略优化。

计算 advantage 是由 rollout_buffer.compute_returns_and_advantage 函数实现的:

rb = rollout_buffer
last_gae_lam = 0
for step in reversed(range(buffer_size)):
if step == buffer_size - 1:
next_non_terminal = 1.0 - dones.astype(np.float32)
next_values = last_values
else:
next_non_terminal = 1.0 - rb.episode_starts[step + 1]
next_values = rb.values[step + 1]
delta = rb.rewards[step] + gamma * next_values * next_non_terminal - rb.values[step] # (1)
last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam # (2)
rb.advantages[step] = last_gae_lam
rb.returns = rb.advantages + rb.values

其中,

  • (1) 行通过类似于 TD error 的形式(A = r + γV(s') - V(s)),计算当前 t 时刻的 advantage;
  • (2) 行则是把 t+1 时刻的 advantage 乘 gamma 和 gae_lambda 传递过来。

03 记录一些踩坑经历

  1. PPO 在收集 rollout 的时候,要在分布里采样,而非采用 argmax 动作,否则没有 exploration。(PPO 在分布里采样 action,这样来保证探索,而非使用 epsilon greedy 等机制;听说 epsilon greedy 机制是 value-based 方法用的)
  2. 如果 policy 网络里有(比如说)batch norm,rollout 时应该把 policy 开 eval 模式,这样就不会出错。
  3. (但是,不要加 batch norm,加 batch norm 性能就不好了。听说 RL 不能加 batch norm)
  4. minigrid 简单环境,RNN 加不加貌似都可以(?)
  5. 在算 entropy loss 的时候,要用真 entropy,从 Categorical 分布里得到的 entropy;不要用 -logprob 近似的,不然会导致策略分布 熵变得很小 炸掉。

RL 基础 | 如何复现 PPO,以及一些踩坑经历的更多相关文章

  1. TiDB 深度实践之旅--真实“踩坑”经历

    美团点评 TiDB 深度实践之旅(9000 字长文 / 真实“踩坑”经历) 4   PingCAP · 154 天前 · 3956 次点击 这是一个创建于 154 天前的主题,其中的信息可能已经有所发 ...

  2. nginx搭建网站踩坑经历

    为了更好的阅读体验,请访问我的个人博客 前言 早上刷抖音刷到一个只需要三步的nginx搭建教程(视频地址),觉得有些离谱,跟着复现了一遍,果然很多地方不严谨并且省略了大量步骤,对于很多不了解linux ...

  3. 『审慎』.Net4.6 Task 异步函数 比 同步函数 慢5倍 踩坑经历

    异步Task简单介绍 本标题有点 哗众取宠,各位都别介意(不排除个人技术能力问题) —— 接下来:我将会用一个小Demo 把 本文思想阐述清楚. .Net 4.0 就有了 Task 函数 —— 异步编 ...

  4. Net4.6 Task 异步函数 比 同步函数 慢5倍 踩坑经历

    Net4.6 Task 异步函数 比 同步函数 慢5倍 踩坑经历 https://www.cnblogs.com/shuxiaolong/p/DotNet_Task_BUG.html 异步Task简单 ...

  5. myeclipse使用db-brower连接到sqlserver2012踩坑经历

    myeclipse使用db-brower连接到sqlserver踩坑经历 首先得建立个角色 右键->创建登录名 权限开大点 连接设置 Driver template选择我选这个,格式按照我的写 ...

  6. sqlserver安装和踩坑经历

    sqlserver安装和踩坑经历 下载 下载 安装 大致是按照这个来的 安装教程 出错 windows系统安装软件弹出"Windows installer service could not ...

  7. Dubbo 服务 IP 注册错误踩坑经历

    个人博客地址 studyidea.cn,点击查看更多原创文章 踩坑 公司最近新建一个机房,需要将现有系统同步部署到新机房,部署完成之后,两地机房同时对提供服务.系统架构如下图: 这个系统当前对外采用 ...

  8. 使用BeanUtils.copyProperties踩坑经历

    1. 原始转换 提起对象转换,每个程序员都不陌生,比如项目中经常涉及到的DO.DTO.VO之间的转换,举个例子,假设现在有个OrderDTO,定义如下所示: public class OrderDTO ...

  9. 【踩坑经历】一次Asp.NET小网站部署踩坑和解决经历

    2013年给1个大学的小客户部署过一个小型的Asp.NET网站,非常小,用的sqlite数据库,今年人家说要换台服务器,要重新部署一下,好吧,虽然早就过了服务时间,但无奈谁叫人家是客户了,二话不说,上 ...

  10. RocketMQ同一个消费者唯一Topic多个tag踩坑经历

    最近做的项目的一个版本需求中,需要用到MQ,对数据记录进行异步落库,这样可以减轻数据库的压力,同时可以抗住大量的数据落库.这里需要说明一下本人用到的MQ是公司自己在阿里的RokectMQ的基础上进行封 ...

随机推荐

  1. 【Appium】之自动化定位总结

    一.同级定位时,先定位上级 我想定位[必填]框,我先定位[姓名]的同一个上级 self.driver.find_element(MobileBy.XPATH,"//*[contains(@t ...

  2. 一个小小空格问题引起的bug

    程序员会遇到一种情况,一个bug排查到最后是由一个很小的问题导致的.在昨天的日常搬砖中遇到一个问题,耽搁了我大半天的时间,最后查明原因让我很无语. 首先介绍一下背景,我是做算法模型训练,目前手上的工作 ...

  3. 从数据洞察到智能决策:合合信息&infiniflow RAG技术的实战案例分享

    从数据洞察到智能决策:合合信息&infiniflow RAG技术的实战案例分享 标题取自 LLamaIndex,这个内容最早提出于今年 2 月份 LlamaIndex 官方博客.从 22 年 ...

  4. ASP.NET Core – Try Preview

    前言 .NET 7 已经来到 RC 阶段了. 通常 RC 就是我们 (写库的人) 要入场的时候了. 有发现 Bug 要尽可能在这段期间提交. 不然后患无穷. 这篇主要就是来讲讲如果测试 RC 版本的 ...

  5. CSS & JS Effect – Button Hover Bling Bling Effect

    效果 原理 一眼看上去, background 有渐变颜色 linear-gradient. 当 hover in 的时候有一束白光, 从右边移动到左边. hover out 则是反过来. 它其实是通 ...

  6. SimpleAISearch:C# + DuckDuckGo 实现简单的AI搜索

    最近AI搜索很火爆,有Perplexity.秘塔AI.MindSearch.Perplexica.memfree.khoj等等. 在使用大语言模型的过程中,或许你也遇到了这种局限,就是无法获取网上最新 ...

  7. QT原理与源码分析之QT反射机制原理

    QT反射机制原理 本文将介绍QT反射机制创建QT对象实例的原理和流程以及源代码. 文章目录 QT反射机制创建QT对象实例 原理 流程 源码 QT反射机制创建QT对象实例 QT框架提供的基于元对象的反射 ...

  8. Spring技术书的代码资源下载

    我是清华社编辑,这些资源获得作者授权,免费提供给读者个人学习使用.禁止任何形式的商用. 二维码用微信扫,按提示填写你的邮箱,转到电脑上打开邮箱下载.清华国企网盘,比较快速.安全.放心下载. 百度网盘链 ...

  9. Android dtbo(1) dto简介

    设备树 (DT, Device Tree) 是用于描述 non-discoverable(google这样写的,意思应该就是硬件信息看不到) 硬件的命名节点和属性构成的一种数据结构.操作系统(例如在 ...

  10. 2021年11月数据库排行解读:openGauss跃居第三,人大金仓晋身前十

    2021年11月墨天轮国产数据库流行度排行榜出炉,本月前三的数据库产品分别是:TiDB.达梦.openGauss,openGauss 数据库是首次跻身前三强. TiDB 自2020年以来,持续稳居榜首 ...