强化学习的REIINFORCE算法和交叉熵RL算法
注意:
本文并不讲REINFORCE算法,而是讲强化学习的交叉熵算法,关于REINFORCE算法可以参看:
https://www.cnblogs.com/devilmaycry812839668/p/15889282.html
==========================================
强化学习有多种分类方法,其中一类分法为:
- 基于值函数的。该种类型的强化学习算法,比较有代表的基础算法有Q-learning算法、Sarsa算法等。
- 基于策略梯度的。该种类型的强化学习算法,比较有代表的基础算法有REINFORCE、交叉熵RL算法等。
本文主要讲交叉熵RL算法。交叉熵RL不同于REINFORCE算法,损失函数中是不使用奖励值的。交叉熵RL在每次和环境交互采集一定数量的episodes数据后根据奖励值选择其中一定比例的episodes数据,然后根据这些选定数据中动作的选择和对应的概率来进行交叉熵损失计算。如果在选定的episodes数据中有某个step,该step中状态可选择的动作为a0,a1,a2,a3这四个动作,假设agent最终选择的动作为a2,计算损失函数时得到在该step下选择a2的概率为p2,那么计算时使用交叉熵函数则可以写为 -(0*logp0 + 0*logp1 + 1*logp2 + 0*logp3 ) = -logp2 。在对episodes数据进行选择时,我们可以根据最终奖励值的大小选择一定百分比的episodes,如选择最好的30%的episodes (在下面代码中百分位数设为70,就是选择最好的30%数据)。
需要注意的是交叉熵RL算法是十分基础的RL算法,缺点也很多,现在很少会有人使用,了解这个算法重要意义在于学习。在交叉熵RL算法可以使用对以往表现好的episodes数据进行保存,然后和新获得的数据一起进行再次训练,该种方式一般叫做保留精英操作。
给出CartPole环境下的一个交叉熵RL算法的代码:(Pytorch框架)
import gym
from collections import namedtuple
import numpy as np
from tensorboardX import SummaryWriter import torch
import torch.nn as nn
import torch.optim as optim HIDDEN_SIZE = 128
BATCH_SIZE = 16
PERCENTILE = 70 class Net(nn.Module):
def __init__(self, obs_size, hidden_size, n_actions):
super(Net, self).__init__()
self.net = nn.Sequential(
nn.Linear(obs_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, n_actions)
) def forward(self, x):
return self.net(x) Episode = namedtuple('Episode', field_names=['reward', 'steps'])
EpisodeStep = namedtuple('EpisodeStep', field_names=['observation', 'action']) def iterate_batches(env, net, batch_size):
batch = []
episode_reward = 0.0
episode_steps = []
obs = env.reset()
sm = nn.Softmax(dim=1)
while True:
obs_v = torch.FloatTensor([obs])
act_probs_v = sm(net(obs_v))
act_probs = act_probs_v.data.numpy()[0]
action = np.random.choice(len(act_probs), p=act_probs)
next_obs, reward, is_done, _ = env.step(action)
episode_reward += reward
step = EpisodeStep(observation=obs, action=action)
episode_steps.append(step)
if is_done:
e = Episode(reward=episode_reward, steps=episode_steps)
batch.append(e)
episode_reward = 0.0
episode_steps = []
next_obs = env.reset()
if len(batch) == batch_size:
yield batch
batch = []
obs = next_obs def filter_batch(batch, percentile):
rewards = list(map(lambda s: s.reward, batch))
reward_bound = np.percentile(rewards, percentile)
reward_mean = float(np.mean(rewards)) train_obs = []
train_act = []
for reward, steps in batch:
if reward < reward_bound:
continue
train_obs.extend(map(lambda step: step.observation, steps))
train_act.extend(map(lambda step: step.action, steps)) train_obs_v = torch.FloatTensor(train_obs)
train_act_v = torch.LongTensor(train_act)
return train_obs_v, train_act_v, reward_bound, reward_mean if __name__ == "__main__":
env = gym.make("CartPole-v0")
# env = gym.wrappers.Monitor(env, directory="mon", force=True)
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n net = Net(obs_size, HIDDEN_SIZE, n_actions)
objective = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.01)
writer = SummaryWriter(comment="-cartpole") for iter_no, batch in enumerate(iterate_batches(
env, net, BATCH_SIZE)):
obs_v, acts_v, reward_b, reward_m = \
filter_batch(batch, PERCENTILE)
optimizer.zero_grad()
action_scores_v = net(obs_v)
loss_v = objective(action_scores_v, acts_v)
loss_v.backward()
optimizer.step()
print("%d: loss=%.3f, reward_mean=%.1f, rw_bound=%.1f" % (
iter_no, loss_v.item(), reward_m, reward_b))
writer.add_scalar("loss", loss_v.item(), iter_no)
writer.add_scalar("reward_bound", reward_b, iter_no)
writer.add_scalar("reward_mean", reward_m, iter_no)
if reward_m > 199:
print("Solved!")
break
writer.close()
============================================
强化学习的REIINFORCE算法和交叉熵算法作为比较基础的算法经常作为baseline被提及,关于REIINFORCE算法可以参看:
https://www.cnblogs.com/devilmaycry812839668/p/15889282.html
============================================
强化学习的REIINFORCE算法和交叉熵RL算法的更多相关文章
- 强化学习中REIINFORCE算法和AC算法在算法理论和实际代码设计中的区别
背景就不介绍了,REINFORCE算法和AC算法是强化学习中基于策略这类的基础算法,这两个算法的算法描述(伪代码)参见Sutton的reinforcement introduction(2nd). A ...
- 统计学习:逻辑回归与交叉熵损失(Pytorch实现)
1. Logistic 分布和对率回归 监督学习的模型可以是概率模型或非概率模型,由条件概率分布\(P(Y|\bm{X})\)或决 策函数(decision function)\(Y=f(\bm{X} ...
- 强化学习(五)—— 策略梯度及reinforce算法
1 概述 在该系列上一篇中介绍的基于价值的深度强化学习方法有它自身的缺点,主要有以下三点: 1)基于价值的强化学习无法很好的处理连续空间的动作问题,或者时高维度的离散动作空间,因为通过价值更新策略时是 ...
- 强化学习(十七) 基于模型的强化学习与Dyna算法框架
在前面我们讨论了基于价值的强化学习(Value Based RL)和基于策略的强化学习模型(Policy Based RL),本篇我们讨论最后一种强化学习流派,基于模型的强化学习(Model Base ...
- 深度强化学习day01初探强化学习
深度强化学习 基本概念 强化学习 强化学习(Reinforcement Learning)是机器学习的一个重要的分支,主要用来解决连续决策的问题.强化学习可以在复杂的.不确定的环境中学习如何实现我们设 ...
- <强化学习>开门帖
(本系列只用作本人笔记,如果看官是以新手开始学习RL,不建议看我写的笔记昂) 今天是2020年2月7日,开始二刷david silver ulc课程.https://www.youtube.com/w ...
- softmax交叉熵损失函数求导
来源:https://www.jianshu.com/p/c02a1fbffad6 简单易懂的softmax交叉熵损失函数求导 来写一个softmax求导的推导过程,不仅可以给自己理清思路,还可以造福 ...
- 机器学习之路:tensorflow 深度学习中 分类问题的损失函数 交叉熵
经典的损失函数----交叉熵 1 交叉熵: 分类问题中使用比较广泛的一种损失函数, 它刻画两个概率分布之间的距离 给定两个概率分布p和q, 交叉熵为: H(p, q) = -∑ p(x) log q( ...
- 强化学习调参技巧二:DDPG、TD3、SAC算法为例:
1.训练环境如何正确编写 强化学习里的 env.reset() env.step() 就是训练环境.其编写流程如下: 1.1 初始阶段: 先写一个简化版的训练环境.把任务难度降到最低,确保一定能正常训 ...
- 深度学习基础5:交叉熵损失函数、MSE、CTC损失适用于字识别语音等序列问题、Balanced L1 Loss适用于目标检测
深度学习基础5:交叉熵损失函数.MSE.CTC损失适用于字识别语音等序列问题.Balanced L1 Loss适用于目标检测 1.交叉熵损失函数 在物理学中,"熵"被用来表示热力学 ...
随机推荐
- 从零开始写 Docker(十八)---容器网络实现(下):为容器插上”网线“
本文为从零开始写 Docker 系列第十八篇,利用 linux 下的 Veth.Bridge.iptables 等等相关技术,构建容器网络模型,为容器插上"网线". 完整代码见:h ...
- Python使用.NET开发的类库来提高你的程序执行效率
Python由于本身的特性原因,执行程序期间可能效率并不是很理想.在某些需要自己提高一些代码的执行效率的时候,可以考虑使用C#.C++.Rust等语言开发的库来提高python本身的执行效率.接下来, ...
- 海量数据处理利器 Roaring BitMap 原理介绍
作者:来自 vivo 互联网服务器团队- Zheng Rui 本文结合个人理解梳理了BitMap及Roaring BitMap的原理及使用,分别主要介绍了Roaring BitMap的存储方式及三种c ...
- echarts 各种特效图
饼图标签展示数值 配置项: option = { title: { text: '项目时间分布', left: 'center' }, tooltip: { trigger: 'item', form ...
- spring使用RedisCacheManager管理key的一些问题
spring可以很好地管理各种内存的快速缓存. 这些常见的内存缓存库实现方式有redis,Ehcache. 本文阐述的是redis,毕竟这个东西相当容易使用. spring通过 org.springf ...
- I2S 总线学习:1-有关概念
背景 I2S总线 是一种常见的总线,也是需要掌握的. 概念 I2S(Inter-IC Sound)总线, 又称 集成电路内置音频总线,是飞利浦公司为数字音频设备之间的音频数据传输而制定的一种总线标准, ...
- Python 潮流周刊#58:最快运行原型的语言(摘要)
本周刊由 Python猫 出品,精心筛选国内外的 250+ 信息源,为你挑选最值得分享的文章.教程.开源项目.软件工具.播客和视频.热门话题等内容.愿景:帮助所有读者精进 Python 技术,并增长职 ...
- js脚本化css
脚本化CSS 我们刚讲过如何获取和设置行内样式的值,但是我们开发不会所有样式都写在行内,同时js没法获取内嵌样式表和外部样式表中的值. 事实上DOM提供了可靠的API,得到计算后的样式. 1. 获取计 ...
- helloworld - 程序员的第一个社区终于来了
helloworld - 程序员的第一个社区终于来了 csdn事件 CSDN旗下的GitCode最近因为一种极其不道德的行为引起了开发者的广泛愤怒和抗议.CSDN在没有通知或征求开发者同意的情况下,悄 ...
- Java类全路径冲突解决方法
1. 问题 今天在开发中遇到这样一个问题,A同事在导入了我们的实验SDK后,发现实验无法正常获取,查看日志发现了NoClassDefFoundError异常,无法加载的的类中逻辑比较简单,只依赖了另外 ...