Related earlier post:

An Experience Replay Framework for Distributed Reinforcement Learning (Reverb: A Framework for Experience Replay)

Paper title:

Reverb: A Framework for Experience Replay

URL:

https://arxiv.org/pdf/2102.04736.pdf

Framework code:

https://github.com/deepmind/reverb

Installation:

pip install dm-reverb[tensorflow]

============================================

Example 1: Overlapping Trajectories

Inserting Overlapping Trajectories

import reverb
import tensorflow as tf

OBSERVATION_SPEC = tf.TensorSpec([10, 10], tf.uint8)
ACTION_SPEC = tf.TensorSpec([2], tf.float32)


def agent_step(unused_timestep) -> tf.Tensor:
    return tf.cast(tf.random.uniform(ACTION_SPEC.shape) > .5,
                   ACTION_SPEC.dtype)


def environment_step(unused_action) -> tf.Tensor:
    return tf.cast(tf.random.uniform(OBSERVATION_SPEC.shape, maxval=256),
                   OBSERVATION_SPEC.dtype)


# Initialize the reverb server.
simple_server = reverb.Server(
    tables=[
        reverb.Table(
            name='my_table',
            sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
            remover=reverb.selectors.Fifo(),
            max_size=int(1e6),
            # Sets Rate Limiter to a low number for the examples.
            # Read the Rate Limiters section for usage info.
            rate_limiter=reverb.rate_limiters.MinSize(2),
            # The signature is optional but it is good practice to set it as it
            # enables data validation and easier dataset construction. Note that
            # we prefix all shapes with a 3 as the trajectories we'll be writing
            # consist of 3 timesteps.
            signature={
                'actions':
                    tf.TensorSpec([3, *ACTION_SPEC.shape], ACTION_SPEC.dtype),
                'observations':
                    tf.TensorSpec([3, *OBSERVATION_SPEC.shape],
                                  OBSERVATION_SPEC.dtype),
            },
        )
    ],
    # Fixes the port to 9999 so that a client in another process can connect.
    # Passing port=None (the default) makes the server pick one automatically.
    port=9999)

# Initializes the reverb client on the same port as the server.
client = reverb.Client(f'localhost:{simple_server.port}')

# Dynamically adds trajectories of length 3 to 'my_table' using a client writer.
with client.trajectory_writer(num_keep_alive_refs=3) as writer:
    timestep = environment_step(None)
    for step in range(4):
        action = agent_step(timestep)
        writer.append({'action': action, 'observation': timestep})
        timestep = environment_step(action)

        if step >= 2:
            # In this example, the item consists of the 3 most recent timesteps that
            # were added to the writer and has a priority of 1.5.
            writer.create_item(
                table='my_table',
                priority=1.5,
                trajectory={
                    'actions': writer.history['action'][-3:],
                    'observations': writer.history['observation'][-3:],
                }
            )

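The server and the client do not have to run on the same host; in this example they do. The example above fixes the server port to 9999. The server's main job is to maintain the data in the replay buffer. A client can both sample and insert; in the example above, the client only inserts.

The server defines how sample and insert are carried out. Although it is the client that calls sample or insert, the operations are ultimately executed on the server, since the server is the one holding the data.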
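To make this division of labour concrete, here is a toy, single-process stand-in for one server-side table. It is a sketch under stated assumptions, not Reverb's implementation: `ToyTable` and its methods are hypothetical names, and the class only mimics the four settings used in the example (prioritized sampling with `priority_exponent=0.8`, FIFO removal, `max_size`, and a `MinSize(2)` rate limiter). Clients would only call `insert` and `sample`; the data and the selection logic stay inside the table.

```python
import random
from collections import OrderedDict


class ToyTable:
    """Hypothetical in-process stand-in for one server-side Reverb table.

    Not the Reverb API: it only mimics the settings used in the example
    (Prioritized sampler, Fifo remover, max_size, MinSize rate limiter).
    """

    def __init__(self, priority_exponent=0.8, max_size=1_000_000, min_size=2):
        self.priority_exponent = priority_exponent
        self.max_size = max_size
        self.min_size = min_size
        # key -> (priority, data); insertion order is kept for FIFO removal.
        self._items = OrderedDict()
        self._next_key = 0

    def insert(self, data, priority):
        # FIFO remover: evict the oldest item once the table is full.
        if len(self._items) >= self.max_size:
            self._items.popitem(last=False)
        key = self._next_key
        self._next_key += 1
        self._items[key] = (priority, data)
        return key

    def probabilities(self):
        # Prioritized sampler: P(i) = p_i ** alpha / sum_k p_k ** alpha.
        weights = {k: p ** self.priority_exponent
                   for k, (p, _) in self._items.items()}
        total = sum(weights.values())
        return {k: w / total for k, w in weights.items()}

    def sample(self, rng=random):
        # MinSize rate limiter: refuse to sample until enough items exist.
        # (A real Reverb server blocks the caller instead of raising.)
        if len(self._items) < self.min_size:
            raise RuntimeError("rate limiter: not enough items yet")
        probs = self.probabilities()
        key = rng.choices(list(probs), weights=list(probs.values()))[0]
        return key, probs[key], self._items[key][1]


table = ToyTable()
table.insert("trajectory over steps 0-2", priority=1.5)
table.insert("trajectory over steps 1-3", priority=1.5)
print(table.probabilities())  # {0: 0.5, 1: 0.5}
```

The example writer creates two items with the same priority of 1.5, so each should be sampled with probability 0.5. This is the kind of value `sample.info.probability` reports in the client script further down.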

Regarding the statement:

with client.trajectory_writer(num_keep_alive_refs=3) as writer:

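My understanding is as follows. Before a client can insert data, it needs a buffer, and num_keep_alive_refs sets the size of that buffer. writer.append writes a step into this client-side buffer, which holds at most num_keep_alive_refs steps. writer.create_item then turns steps that are still in the buffer into an item that is inserted on the server. So at the moment create_item is called, every step the item references must still be held in the buffer. If num_keep_alive_refs is too small, that data is already gone and referencing it raises an error. You could simply set num_keep_alive_refs to a very large number, but that wastes memory on the client. (The Reverb API documentation describes num_keep_alive_refs as the size of the circular buffer that each column keeps for its most recently appended data. It therefore acts as the maximum number of steps a single trajectory can span.)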

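Because this buffer has a fixed size, each writer.append evicts the oldest entry once the buffer is full. In the example above, each item spans the 3 most recent steps. Setting

with client.trajectory_writer(num_keep_alive_refs=2) as writer:

therefore raises an error, while

with client.trajectory_writer(num_keep_alive_refs=4) as writer:

does not.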
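The eviction behaviour can be simulated with a bounded `collections.deque`. The sketch below is a simplified model and not Reverb code. `simulate_writer` is a hypothetical helper that treats one writer column as a deque of size `num_keep_alive_refs`, and it stands in for `writer.append` and `writer.create_item` with plain list operations. It is meant to show why 2 is too small and 3 or 4 are enough for items of length 3.

```python
from collections import deque


def simulate_writer(num_keep_alive_refs, num_steps=4, item_length=3):
    """Hypothetical model of one TrajectoryWriter column (not Reverb code).

    The column keeps only the last `num_keep_alive_refs` appended steps, and an
    item may only reference steps that are still kept. Returns the step indices
    covered by each created item.
    """
    kept = deque(maxlen=num_keep_alive_refs)  # the oldest step falls out when full
    items = []
    for step in range(num_steps):
        kept.append(step)  # plays the role of writer.append(...)
        if step >= item_length - 1:
            wanted = list(range(step - item_length + 1, step + 1))
            missing = [s for s in wanted if s not in kept]
            if missing:
                raise ValueError(f"steps {missing} were already dropped from the buffer")
            items.append(wanted)  # plays the role of writer.create_item(...)
    return items


print(simulate_writer(3))  # [[0, 1, 2], [1, 2, 3]]
print(simulate_writer(4))  # [[0, 1, 2], [1, 2, 3]]
# simulate_writer(2) raises ValueError: step 0 is gone when the first item is created.
```

The two items share steps 1 and 2. This is the overlap in the example's title, "Overlapping Trajectories".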

Sampling Overlapping Trajectories in TensorFlow

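On a single host, run the server-side code below. It inserts the two items and then sleeps so that the server stays alive: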

import time

import reverb
import tensorflow as tf

OBSERVATION_SPEC = tf.TensorSpec([10, 10], tf.uint8)
ACTION_SPEC = tf.TensorSpec([2], tf.float32)


def agent_step(unused_timestep) -> tf.Tensor:
    return tf.cast(tf.random.uniform(ACTION_SPEC.shape) > .5,
                   ACTION_SPEC.dtype)


def environment_step(unused_action) -> tf.Tensor:
    return tf.cast(tf.random.uniform(OBSERVATION_SPEC.shape, maxval=256),
                   OBSERVATION_SPEC.dtype)


# Initialize the reverb server.
simple_server = reverb.Server(
    tables=[
        reverb.Table(
            name='my_table',
            sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
            remover=reverb.selectors.Fifo(),
            max_size=int(1e6),
            # Sets Rate Limiter to a low number for the examples.
            # Read the Rate Limiters section for usage info.
            rate_limiter=reverb.rate_limiters.MinSize(2),
            # The signature is optional but it is good practice to set it as it
            # enables data validation and easier dataset construction. Note that
            # we prefix all shapes with a 3 as the trajectories we'll be writing
            # consist of 3 timesteps.
            signature={
                'actions':
                    tf.TensorSpec([3, *ACTION_SPEC.shape], ACTION_SPEC.dtype),
                'observations':
                    tf.TensorSpec([3, *OBSERVATION_SPEC.shape],
                                  OBSERVATION_SPEC.dtype),
            },
        )
    ],
    # Fixes the port to 9999 so that the client script can connect to it.
    # Passing port=None (the default) makes the server pick one automatically.
    port=9999)

# Initializes the reverb client on the same port as the server.
client = reverb.Client(f'localhost:{simple_server.port}')

# Dynamically adds trajectories of length 3 to 'my_table' using a client writer.
with client.trajectory_writer(num_keep_alive_refs=3) as writer:
    timestep = environment_step(None)
    for step in range(4):
        action = agent_step(timestep)
        writer.append({'action': action, 'observation': timestep})
        timestep = environment_step(action)

        if step >= 2:
            # In this example, the item consists of the 3 most recent timesteps that
            # were added to the writer and has a priority of 1.5.
            writer.create_item(
                table='my_table',
                priority=1.5,
                trajectory={
                    'actions': writer.history['action'][-3:],
                    'observations': writer.history['observation'][-3:],
                }
            )

# Keep this process (and therefore the server) alive so that the client
# script can connect and sample.
time.sleep(3333333)

While it is running, run the client code:

import reverb

# Dataset samples whole trajectories of length 3, as described by the table
# signature of 'my_table'.
dataset = reverb.TrajectoryDataset.from_table_signature(
    server_address='localhost:9999',
    table='my_table',
    max_in_flight_samples_per_worker=10)

# Batches 2 trajectories together.
# Observations are now shaped [2, 3, 10, 10].
batched_dataset = dataset.batch(2)

for sample in batched_dataset.take(2):
    # Results in the following format.
    print(sample.info.key)              # ([2], uint64)
    print(sample.info.probability)      # ([2], float64)
    print(sample.data['observations'])  # ([2, 3, 10, 10], uint8)
    print(sample.data['actions'])       # ([2, 3, 2], float32)

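Here, dataset.batch(2) sets the batch size of each sample: every element of batched_dataset stacks 2 sampled trajectories.

The statement for sample in batched_dataset.take(2): limits how many elements the iterator yields, that is, how many batches are returned. It is set to 2 here, so the for loop runs twice.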
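The interplay of `batch` and `take` can be modelled in plain Python. This is only an analogy. `batch` and `endless_samples` below are hypothetical helpers, not tf.data or Reverb code. They also skip the tensor stacking that turns two `[3, 10, 10]` observations into one `[2, 3, 10, 10]` tensor. The endless generator reflects the assumption that a Reverb dataset keeps sampling for as long as the server is up, so `take` is what ends the loop.

```python
from itertools import islice


def batch(iterable, batch_size):
    """Groups consecutive elements, like Dataset.batch but without stacking tensors."""
    it = iter(iterable)
    while True:
        group = list(islice(it, batch_size))
        if not group:
            return
        yield group


def endless_samples():
    # Hypothetical stand-in for the Reverb dataset: it never runs out on its own.
    i = 0
    while True:
        yield f"trajectory_{i}"
        i += 1


loops = 0
for sample in islice(batch(endless_samples(), 2), 2):  # ~ dataset.batch(2).take(2)
    loops += 1
    print(len(sample), sample)
print("loop iterations:", loops)  # loop iterations: 2
```

Real Reverb samples are drawn at random and can repeat, since sampling does not remove items; the toy generator just numbers them to make the grouping visible.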

===================================

Other related code is available at:

https://github.com/deepmind/reverb/blob/master/examples/demo.ipynb

https://github.com/deepmind/reverb/blob/master/examples/frame_stacking.ipynb

===================================