tensorflow的官方强化学习库agents的相关内容及一些注意事项
源代码地址:
https://github.com/tensorflow/agents
TensorFlow给出的官方文档说明:
https://tensorflow.google.cn/agents
相关视频:
https://www.youtube.com/watch?v=U7g7-Jzj9qo
https://www.youtube.com/watch?v=tAOApRQAgpc
https://www.youtube.com/watch?v=52DTXidSVWc&list=PLQY2H8rRoyvxWE6bWx8XiMvyZFgg_25Q_&index=2
-----------------------------------------------------------
框架实现的算法:
论文1:
论文3:
论文4:
论文5:
论文6:
论文7:
论文8:
论文9:
论文10:
====================================
1. gym的环境版本有要求,给出具体安装及Atari的安装:
pip install gym[atari]==0.23.0
pip install gym[accept-rom-license]
=====================================
2. 代码的逻辑bug
tf_agents/specs/array_spec.py 代码bug:
def sample_bounded_spec(spec, rng):
"""Samples the given bounded spec. Args:
spec: A BoundedSpec to sample.
rng: A numpy RandomState to use for the sampling. Returns:
An np.array sample of the requested spec.
"""
tf_dtype = tf.as_dtype(spec.dtype)
low = spec.minimum
high = spec.maximum if tf_dtype.is_floating:
if spec.dtype == np.float64 and np.any(np.isinf(high - low)):
# The min-max interval cannot be represented by the np.float64. This is a
# problem only for np.float64, np.float32 works as expected.
# Spec bounds are set to read only so we can't use argumented assignment.
low = low / 2
high = high / 2
return rng.uniform(
low,
high,
size=spec.shape,
).astype(spec.dtype) else:
if spec.dtype == np.int64 and np.any(high - low < 0):
# The min-max interval cannot be represented by the tf_dtype. This is a
# problem only for int64.
low = low / 2
high = high / 2 if np.any(high < tf_dtype.max):
high = np.where(high < tf_dtype.max, high + 1, high)
elif spec.dtype != np.int64 or spec.dtype != np.uint64:
# We can still +1 the high if we cast it to the larger dtype.
high = high.astype(np.int64) + 1 if low.size == 1 and high.size == 1:
return rng.randint(
low,
high,
size=spec.shape,
dtype=spec.dtype,
)
else:
return np.reshape(
np.array([
rng.randint(low, high, size=1, dtype=spec.dtype)
for low, high in zip(low.flatten(), high.flatten())
]), spec.shape)
这个代码的意思就是在给定区间能进行均匀抽样,但是由于区间可能过大因此导致无法使用库函数抽样,因此需要对区间进行压缩。
当抽样的数据类型为np.float64时,由代码:
np.array(np.zeros(10), dtype=np.float64)+np.finfo(np.float64).max-np.finfo(np.float64).min

np.random.uniform(size=10, low=np.finfo(np.float64).min, high=np.finfo(np.float64).max)

可以知道,当类型为np.float64时,如果抽样区间过大会(超出数值表示范围)导致无法抽样,因此进行压缩区间:

当数据类型为np.float32时,虽然也会存在超出表示范围的问题:
np.array(np.zeros(10), dtype=np.float32)+np.finfo(np.float32).max-np.finfo(np.float32).min

但是由于函数 np.random.uniform 的计算中会把np.float32转为np.float64,因此不会出现报错,如下:
np.random.uniform(size=10, low=np.finfo(np.float32).min, high=np.finfo(np.float32).max)

---------------------------------------------------
当数值类型为int时,区间访问过大的检测代码为:
np.any(high - low < 0)
原因在意np.float类型数值超出表示范围会表示为infi变量,但是int类型则会以溢出形式表现,如:

但是在使用numpy.random.randint 函数时,即使范围为最大范围也没有报错:
np.random.randint(size=10000, low=tf.as_dtype(np.int64).min, high=tf.as_dtype(np.int64).max)

np.random.randint(size=10000, low=np.iinfo(np.int64).min, high=np.iinfo(np.int64).max)

而且即使由于high值是取开区间的,我们对high值加1以后也没有报错:

但是需要注意,此时传给np.random.randint函数中的low和high数值都为python数据类型int而不是numpy中的np.int64,下面我们看下numpy.float64类型是否会溢出:
当以数组形式传递最高high值并使其保持np.float64类型,发现使用high+1就会溢出:

可以看到使用最大范围+1作为high值会导致报错:

可以看到在使用numpy.random.randint时对上下限还是要注意的,虽然numpy.random.randint对上限是开区间,但是+1操作是很可能引起溢出错误的。
这也就是为什么 high+1操作之前要做判断了,如下:

不过如果数据类型不为np.int64,并且也不为np.uint64,那么我们依然可以把high值转为np.int64后在+1 ,但是上面的逻辑判断是有一定问题的,这些修正后如下:

总的修改后的代码为:
def sample_bounded_spec(spec, rng):
"""Samples the given bounded spec. Args:
spec: A BoundedSpec to sample.
rng: A numpy RandomState to use for the sampling. Returns:
An np.array sample of the requested spec.
"""
tf_dtype = tf.as_dtype(spec.dtype)
low = spec.minimum
high = spec.maximum if tf_dtype.is_floating:
if spec.dtype == np.float64 and np.any(np.isinf(high - low)):
# The min-max interval cannot be represented by the np.float64. This is a
# problem only for np.float64, np.float32 works as expected.
# Spec bounds are set to read only so we can't use argumented assignment.
low = low / 2
high = high / 2
return rng.uniform(
low,
high,
size=spec.shape,
).astype(spec.dtype) else:
if spec.dtype == np.int64 and np.any(high - low < 0):
# The min-max interval cannot be represented by the tf_dtype. This is a
# problem only for int64.
low = low / 2
high = high / 2 if np.any(high < tf_dtype.max):
high = np.where(high < tf_dtype.max, high + 1, high)
elif spec.dtype != np.int64 and spec.dtype != np.uint64:
# We can still +1 the high if we cast it to the larger dtype.
high = high.astype(np.int64) + 1 if low.size == 1 and high.size == 1:
return rng.randint(
low,
high,
size=spec.shape,
dtype=spec.dtype,
)
else:
return np.reshape(
np.array([
rng.randint(low, high, size=1, dtype=spec.dtype)
for low, high in zip(low.flatten(), high.flatten())
]), spec.shape)
加入对np.uint64类型的判断:
def sample_bounded_spec(spec, rng):
"""Samples the given bounded spec. Args:
spec: A BoundedSpec to sample.
rng: A numpy RandomState to use for the sampling. Returns:
An np.array sample of the requested spec.
"""
tf_dtype = tf.as_dtype(spec.dtype)
low = spec.minimum
high = spec.maximum if tf_dtype.is_floating:
if spec.dtype == np.float64 and np.any(np.isinf(high - low)):
# The min-max interval cannot be represented by the np.float64. This is a
# problem only for np.float64, np.float32 works as expected.
# Spec bounds are set to read only so we can't use argumented assignment.
low = low / 2
high = high / 2
return rng.uniform(
low,
high,
size=spec.shape,
).astype(spec.dtype) else:
if spec.dtype == np.int64 and np.any(high - low < 0):
# The min-max interval cannot be represented by the tf_dtype. This is a
# problem only for int64.
low = low / 2
high = high / 2 if spec.dtype == np.uint64 and np.any(high >= np.iinfo(np.int64).max):
low = low / 2
high = high / 2 if np.any(high < tf_dtype.max):
high = np.where(high < tf_dtype.max, high + 1, high)
elif spec.dtype != np.int64 and spec.dtype != np.uint64:
# We can still +1 the high if we cast it to the larger dtype.
high = high.astype(np.int64) + 1 if low.size == 1 and high.size == 1:
return rng.randint(
low,
high,
size=spec.shape,
dtype=spec.dtype,
)
else:
return np.reshape(
np.array([
rng.randint(low, high, size=1, dtype=spec.dtype)
for low, high in zip(low.flatten(), high.flatten())
]), spec.shape)
===========================================
tf-agents框架的交互逻辑代码:
===========================================








tensorflow的官方强化学习库agents的相关内容及一些注意事项的更多相关文章
- 学习笔记之html5相关内容
写一下昨天学习的html5的相关内容,首先谈下初次接触html5的感受.以前总是听说html5是如何的强大,如何的将要改变世界.总是充满了神秘感.首先来谈一下我接触的第一个属性是 input的里面的 ...
- 强化学习之四:基于策略的Agents (Policy-based Agents)
本文是对Arthur Juliani在Medium平台发布的强化学习系列教程的个人中文翻译,该翻译是基于个人分享知识的目的进行的,欢迎交流!(This article is my personal t ...
- 强化学习之三:双臂赌博机(Two-armed Bandit)
本文是对Arthur Juliani在Medium平台发布的强化学习系列教程的个人中文翻译,该翻译是基于个人分享知识的目的进行的,欢迎交流!(This article is my personal t ...
- DQN 强化学习
pytorch比tenserflow简单. 所以我们模仿用tensorflow写的强化学习. 学习资料: 本节的全部代码 Tensorflow 的 100行 DQN 代码 我制作的 DQN 动画简介 ...
- 强化学习之七:Visualizing an Agent’s Thoughts and Actions
本文是对Arthur Juliani在Medium平台发布的强化学习系列教程的个人中文翻译,该翻译是基于个人分享知识的目的进行的,欢迎交流!(This article is my personal t ...
- 强化学习之六:Deep Q-Network and Beyond
本文是对Arthur Juliani在Medium平台发布的强化学习系列教程的个人中文翻译,该翻译是基于个人分享知识的目的进行的,欢迎交流!(This article is my personal t ...
- 强化学习之五:基于模型的强化学习(Model-based RL)
本文是对Arthur Juliani在Medium平台发布的强化学习系列教程的个人中文翻译,该翻译是基于个人分享知识的目的进行的,欢迎交流!(This article is my personal t ...
- 强化学习之三点五:上下文赌博机(Contextual Bandits)
本文是对Arthur Juliani在Medium平台发布的强化学习系列教程的个人中文翻译,该翻译是基于个人分享知识的目的进行的,欢迎交流!(This article is my personal t ...
- 强化学习之二:Q-Learning原理及表与神经网络的实现(Q-Learning with Tables and Neural Networks)
本文是对Arthur Juliani在Medium平台发布的强化学习系列教程的个人中文翻译.(This article is my personal translation for the tutor ...
- 深度强化学习(DRL)专栏(一)
目录: 1. 引言 专栏知识结构 从AlphaGo看深度强化学习 2. 强化学习基础知识 强化学习问题 马尔科夫决策过程 最优价值函数和贝尔曼方程 3. 有模型的强化学习方法 价值迭代 策略迭代 4. ...
随机推荐
- java.util.Date和java.sql.Date有什么区别?
java.util.Date包含日期和时间,而java.sql.Date只包含日期信息,而没有具体的时间信息.如果你想把时间信息存储在数据库 里,可以考虑使用Timestamp或者Date ...
- recastnavigation.Sample_TempObstacles代码注解 - rcBuildHeightfieldLayers
烘培代码在 rcBuildHeightfieldLayers 本质上是为每个tile生成高度上的不同layer 算法的关键是三层循环: for z 轴循环 for x 轴循环 for 高度span 循 ...
- USB 协议学习:000-有关概念
USB 协议学习:000-有关概念 背景 USB作为一种串行接口,应用非常广泛.掌握usb也是作为嵌入式工程师的一项具体要求. 概述 USB( Universal Serial Bus, 通用串行总线 ...
- SD中的VAE,你不能不懂
什么是VAE? VAE,即变分自编码器(Variational Autoencoder),是一种生成模型,它通过学习输入数据的潜在表示来重构输入数据. 在Stable Diffusion 1.4 或 ...
- mysql求同比环比
-- 参考:SQL计算月环比.月同比_路易吃泡面的博客-CSDN博客 -- mysql同比环比 drop table if EXISTS ordertable; create table ordert ...
- 推荐王牌远程桌面软件Getscreen,所有的远程桌面软件中使用最简单的一个
今天要推荐的远程桌面软件就是这款叫Getscreen的,推荐理由挺简单: 简单易用:只需要两步就能轻松连上远程桌面 第一步:在需要被远程连接的机器上下载它的Agent程序并启动,点击Send获得一个链 ...
- PHP+Redis 实例【一】点赞 + 热度
前言 点赞其实是一个很有意思的功能.基本的设计思路有大致两种, 一种自然是用mysql(写了几百行的代码都还没写完,有毒)啦 数据库直接落地存储, 另外一种就是利用点赞的业务特征来扔到redis(或m ...
- [oeasy]python0140_导入_import_from_as_namespace_
导入import 回忆上次内容 上次学习了 try except 注意要点 半角冒号 缩进 输出错误信息 有错就报告 不要隐瞒 否则找不到出错位置 还可以用traceback把 系统报错信息原 ...
- AT_abc182_d 题解
洛谷链接&Atcoder 链接 本篇题解为此题较简单做法及较少码量,并且码风优良,请放心阅读. 题目简述 从数轴的原点开始向正方向走. 第一次向前走 \(a_1\) 步,第二次向前走 \(a_ ...
- C#:SqlSugar中时间戳(TimeStamp)的使用
1.数据库建表 CREATE TABLE dbo.Test ( tId INT IDENTITY NOT NULL , tName NVARCHAR (20) NOT NULL , tSalary D ...