baselines算法库baselines/common/input.py模块分析
baselines算法库baselines/common/input.py模块代码:
import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box, MultiDiscrete def observation_placeholder(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space Parameters:
---------- ob_space: gym.Space observation space batch_size: int size of the batch to be fed into input. Can be left None in most cases. name: str name of the placeholder Returns:
------- tensorflow placeholder tensor
''' assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
'Can only deal with Discrete and Box observation spaces for now' dtype = ob_space.dtype
if dtype == np.int8:
dtype = np.uint8 return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) def observation_input(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space, and add input
encoder of the appropriate type.
''' placeholder = observation_placeholder(ob_space, batch_size, name)
return placeholder, encode_observation(ob_space, placeholder) def encode_observation(ob_space, placeholder):
'''
Encode input in the way that is appropriate to the observation space Parameters:
---------- ob_space: gym.Space observation space placeholder: tf.placeholder observation input placeholder
'''
if isinstance(ob_space, Discrete):
return tf.to_float(tf.one_hot(placeholder, ob_space.n))
elif isinstance(ob_space, Box):
return tf.to_float(placeholder)
elif isinstance(ob_space, MultiDiscrete):
placeholder = tf.cast(placeholder, tf.int32)
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
else:
raise NotImplementedError
可以看到input.py模块中一共有三个函数,其中只有一个函数对外提供服务,也就是 observation_input 。

可以看到observation_placeholder函数和encode_observation函数都已经被observation_input函数包装到了一起。
在observation_placeholder函数中根据传入的 env.observation_space变量即可生成对应shape的tf.placeholder变量:
return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
也就是说observation_input函数中的placeholder已经是tf.placeholder类型的了。
encoder_observation 函数根据gym.spaces.observation_space的类型对tf.placeholder进行reshape操作,而tf.placeholder已经是TensorFlow的tensor变量,因此这里对tf.placeholder的操作都是在图中的构建操作,属于TensorFlow的操作。
根据encoder_observation中的代码:

我们可以知道对placeholder的reshape操作主要是对gym.space.observation_space属于Discrete和MultiDiscrete类型进行的。
如果传入的gym.space.observation_space为Discrete类型则对其对应的placeholder进行 tf.one_hot 操作,即:
tf.one_hot(placeholder, ob_space.n)
如果传入的gym.space.observation_space为MultiDiscrete类型则对其对应的placeholder中的每个Discrete进行 tf.one_hot 操作然后在concat拼接,即:
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
其中,如果gym.space.observation_space为Discrete,则其observation空间的大小为env.observation_space.n 。
其中,如果gym.space.observation_space为MultiDiscrete,则其包含的第i个Discrete对应的observation空间的大小为env.observation_space.nvec[i] 。
=======================================
observation_input函数返回的两个变量,其中placeholder是为了给网络feed数据的,而对placeholder进行reshape后的变量是为了方便后面构建神经网络的。


可以知道在gym中gym.space.MultiDiscrete主要是任天堂的游戏使用的,Nintendo Game Controller 。
import gym env_space=gym.spaces.MultiDiscrete([ 5, 2, 2 ]) for i in range(env_space.shape[-1]):
print(env_space.nvec[i])

可以看到MultiDiscrete中的每个Discrete的空间大小使用nvec对应的索引查询。
一般,gym的常见observation空间类型有:
gym.space.Box
gym.spaces.Discrete
gym.spaces.MultiDiscrete
gym.spaces.MultiBinary
gym.spaces.Tuple
gym.spaces.Dict
其中:
gym.spaces.Box
gym.spaces.Discrete
gym.spaces.MultiDiscrete
gym.spaces.MultiBinary
对应的observation均为np.array类型,在这里:
gym.spaces.Discrete
gym.spaces.MultiDiscrete
由于状态空间可以进行one_hot处理以便于后面的计算,因此在common/input.py模块中的obervation_input函数对这样的情况进行处理,对原始的placeholder进行one_hot处理。
由于gym.spaces.MultiBinary没法进行one_hot操作,而observation_space属于类型:gym.spaces.Tuple和gym.spaces.Dict本身已经在baselines库中的其他模块被处理,由于在baselines库中已经对gym的env的observation进行了包装处理,所以可以保证在env.step和env.reset后获得的observation一定是np.arrray类型的,也就是说gym.spaces.Dict和gym.spaces.Tuple类型已经被处理了,feed给神经网络的observation只能是:gym.spaces.Box、gym.spaces.Discrete、gym.spaces.MultiDiscrete、gym.spaces.MultiBinary类型。
查看了一下baselines库的源码对env.observation的处理的代码,该部分代码在common/vec_env/util.py代码中,在该代码中如果observation_space如果是 gym.spaces.Dict和gym.spaces.Tuple则均转为gym.spaces.Dict类型,其中gym.spaces.Tuple转为gym.spaces.Dict时key用从0开始的数字代替,也就是说baselines库中没有对gym.spaces.Dict和gym.spaces.Tuple类型的observation进行过多的处理,也就是说如果原生的env环境为gym.spaces.Dict和gym.spaces.Tuple则传给算法模块的env类型也只能是gym.spaces.Dict类型,这时如果对这样的observation生成的gym.spaces.Dict类型的observation进行placeholder操作则会报错:

或者说在deepq算法中baselines库只能接收的observation类型只可以为gym.spaces.Discrete、gym.spaces.Box、gym.spaces.MultiDiscrete 。
===============================================
baselines算法库baselines/common/input.py模块分析的更多相关文章
- 图像滤镜艺术---ZPhotoEngine超级算法库
原文:图像滤镜艺术---ZPhotoEngine超级算法库 一直以来,都有个想法,想要做一个属于自己的图像算法库,这个想法,在经过了几个月的努力之后,终于诞生了,这就是ZPhotoEngine算法库. ...
- snowland-smx密码算法库
snowland-smx密码算法库 一.snowland-smx密码算法库的介绍 snowland-smx是python实现的国密套件,对标python实现的gmssl,包含国密SM2,SM3,SM4 ...
- scikit-learn 支持向量机算法库使用小结
之前通过一个系列对支持向量机(以下简称SVM)算法的原理做了一个总结,本文从实践的角度对scikit-learn SVM算法库的使用做一个小结.scikit-learn SVM算法库封装了libsvm ...
- 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块
详细解读Python的web.py框架下的application.py模块 这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...
- 第三百零六节,Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置
Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...
- 使用织梦开源的分词算法库编写的YII获取分词扩展
在编辑文章中,很多时候都需要自动根据文章内容获取关键字的功能,因此,本文主要是说明如何在yii中使用织梦开源的分词算法编写一个独立的扩展,可以在不同的模块中使用,步骤如下: 1 到这里下载其他朋友整理 ...
- 四 Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置
Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...
- mahout算法库(四)
mahout算法库 分为三大块 1.聚类算法 2.协同过滤算法(一般用于推荐) 协同过滤算法也可以称为推荐算法!!! 3.分类算法 算法类 算法名 中文名 分类算法 Log ...
- 操作MySQL-数据库的安装及Pycharm模块的导入
操作MySQL-数据库的安装及Pycharm模块的导入 1.基于pyCharm开发环境,在CMD控制台输入依次输入以下步骤: (1)pip3 install PyMySQL < 安装 PyMy ...
- Magicodes.Pay,打造开箱即用的统一支付库,已提供ABP模块封装
Magicodes.Pay,打造开箱即用的统一支付库,已提供ABP模块封装 简介 Magicodes.Pay,是心莱科技团队提供的统一支付库,相关库均使用.NET标准库编写,支持.NET Framew ...
随机推荐
- 串口收发UART(Verilog HDL)
UART(Universal Asynchronous Receiver Transmitter,通用异步收发器)是一种异步串行通信协议,主要用于计算机和嵌入式系统之间的数据交换. 实现UART通信的 ...
- npm 发布自己的组件库
npm 发布组件库步骤 第一步:注册 npm 账号 第二步:编写自己的组件库 第三部:编写 package.json 可以通过命令生成 npm init { "name": &qu ...
- Excel poi 设置单元格格式 发现不可读内容 已修复的记录: /xl/worksheets/sheet1.xml 部分的问题(巨坑)
Excel poi 设置单元格格式 发现不可读内容 已修复的记录: /xl/worksheets/sheet1.xml 部分的问题(巨坑) 1.先设置值,后设置样式. 正确的是:先设置样式,后设置值. ...
- 高德地图查询结果返回INVALID_USER_IP错误解决
高德地图查询结果返回INVALID_USER_IP错误解决 方法是添加白名单.IP白名单出错,发送请求的服务器IP不在IP白名单内开发者在LBS官网控制台设置的IP白名单不正确.白名单中未添加对应服务 ...
- MoneyPrinterPlus:AI自动短视频生成工具,详细使用教程
MoneyPrinterPlus是一款使用AI大模型技术,一键批量生成各类短视频,自动批量混剪短视频,自动把视频发布到抖音,快手,小红书,视频号上的轻松赚钱工具. 之前有出过一期基本的介绍,但是后台收 ...
- SpringBoot的Security和OAuth2的使用
创建项目 先创建一个spring项目. 然后编写pom文件如下,引入spring-boot-starter-security,我这里使用的spring boot是2.4.2,这里使用使用spring- ...
- JS的JSON.parse问题
这个问题,已经有非常多人说过,而且由来已久. 大家都提供了不少的解决方法,但是都不够彻底. 一)现在是什么情况 1.使用SpringMvc+ModelAndView+jsp传递值 由于业务需要,通过m ...
- Linux 特权 SUID/SGID 的详解
导航 0 前言 1 权限匹配流程 2 五种身份变化 3 有效用户/组 4 特权对 Shell 脚本无效 5 Sudo 与 SUID/SGID 的优先级 6 SUID.SGID.Sticky 各自的功能 ...
- 卷积神经网络中nn.Conv2d()和nn.MaxPool2d()以及卷积神经网络实现minist数据集分类
卷积神经网络中nn.Conv2d()和nn.MaxPool2d() 卷积神经网络之Pythorch实现: nn.Conv2d()就是PyTorch中的卷积模块 参数列表 参数 作用 in_channe ...
- 【ClickHouse】0:clickhouse学习3之时间日期函数
官方文档: https://clickhouse.tech/docs/zh/sql-reference/functions/date-time-functions/ 常用的clickhouse时间函数 ...