baselines算法库baselines/common/input.py模块分析
baselines算法库baselines/common/input.py模块代码:
import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box, MultiDiscrete def observation_placeholder(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space Parameters:
---------- ob_space: gym.Space observation space batch_size: int size of the batch to be fed into input. Can be left None in most cases. name: str name of the placeholder Returns:
------- tensorflow placeholder tensor
''' assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
'Can only deal with Discrete and Box observation spaces for now' dtype = ob_space.dtype
if dtype == np.int8:
dtype = np.uint8 return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) def observation_input(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space, and add input
encoder of the appropriate type.
''' placeholder = observation_placeholder(ob_space, batch_size, name)
return placeholder, encode_observation(ob_space, placeholder) def encode_observation(ob_space, placeholder):
'''
Encode input in the way that is appropriate to the observation space Parameters:
---------- ob_space: gym.Space observation space placeholder: tf.placeholder observation input placeholder
'''
if isinstance(ob_space, Discrete):
return tf.to_float(tf.one_hot(placeholder, ob_space.n))
elif isinstance(ob_space, Box):
return tf.to_float(placeholder)
elif isinstance(ob_space, MultiDiscrete):
placeholder = tf.cast(placeholder, tf.int32)
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
else:
raise NotImplementedError
可以看到input.py模块中一共有三个函数,其中只有一个函数对外提供服务,也就是 observation_input 。

可以看到observation_placeholder函数和encode_observation函数都已经被observation_input函数包装到了一起。
在observation_placeholder函数中根据传入的 env.observation_space变量即可生成对应shape的tf.placeholder变量:
return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
也就是说observation_input函数中的placeholder已经是tf.placeholder类型的了。
encoder_observation 函数根据gym.spaces.observation_space的类型对tf.placeholder进行reshape操作,而tf.placeholder已经是TensorFlow的tensor变量,因此这里对tf.placeholder的操作都是在图中的构建操作,属于TensorFlow的操作。
根据encoder_observation中的代码:

我们可以知道对placeholder的reshape操作主要是对gym.space.observation_space属于Discrete和MultiDiscrete类型进行的。
如果传入的gym.space.observation_space为Discrete类型则对其对应的placeholder进行 tf.one_hot 操作,即:
tf.one_hot(placeholder, ob_space.n)
如果传入的gym.space.observation_space为MultiDiscrete类型则对其对应的placeholder中的每个Discrete进行 tf.one_hot 操作然后在concat拼接,即:
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
其中,如果gym.space.observation_space为Discrete,则其observation空间的大小为env.observation_space.n 。
其中,如果gym.space.observation_space为MultiDiscrete,则其包含的第i个Discrete对应的observation空间的大小为env.observation_space.nvec[i] 。
=======================================
observation_input函数返回的两个变量,其中placeholder是为了给网络feed数据的,而对placeholder进行reshape后的变量是为了方便后面构建神经网络的。


可以知道在gym中gym.space.MultiDiscrete主要是任天堂的游戏使用的,Nintendo Game Controller 。
import gym env_space=gym.spaces.MultiDiscrete([ 5, 2, 2 ]) for i in range(env_space.shape[-1]):
print(env_space.nvec[i])

可以看到MultiDiscrete中的每个Discrete的空间大小使用nvec对应的索引查询。
一般,gym的常见observation空间类型有:
gym.space.Box
gym.spaces.Discrete
gym.spaces.MultiDiscrete
gym.spaces.MultiBinary
gym.spaces.Tuple
gym.spaces.Dict
其中:
gym.spaces.Box
gym.spaces.Discrete
gym.spaces.MultiDiscrete
gym.spaces.MultiBinary
对应的observation均为np.array类型,在这里:
gym.spaces.Discrete
gym.spaces.MultiDiscrete
由于状态空间可以进行one_hot处理以便于后面的计算,因此在common/input.py模块中的obervation_input函数对这样的情况进行处理,对原始的placeholder进行one_hot处理。
由于gym.spaces.MultiBinary没法进行one_hot操作,而observation_space属于类型:gym.spaces.Tuple和gym.spaces.Dict本身已经在baselines库中的其他模块被处理,由于在baselines库中已经对gym的env的observation进行了包装处理,所以可以保证在env.step和env.reset后获得的observation一定是np.arrray类型的,也就是说gym.spaces.Dict和gym.spaces.Tuple类型已经被处理了,feed给神经网络的observation只能是:gym.spaces.Box、gym.spaces.Discrete、gym.spaces.MultiDiscrete、gym.spaces.MultiBinary类型。
查看了一下baselines库的源码对env.observation的处理的代码,该部分代码在common/vec_env/util.py代码中,在该代码中如果observation_space如果是 gym.spaces.Dict和gym.spaces.Tuple则均转为gym.spaces.Dict类型,其中gym.spaces.Tuple转为gym.spaces.Dict时key用从0开始的数字代替,也就是说baselines库中没有对gym.spaces.Dict和gym.spaces.Tuple类型的observation进行过多的处理,也就是说如果原生的env环境为gym.spaces.Dict和gym.spaces.Tuple则传给算法模块的env类型也只能是gym.spaces.Dict类型,这时如果对这样的observation生成的gym.spaces.Dict类型的observation进行placeholder操作则会报错:

或者说在deepq算法中baselines库只能接收的observation类型只可以为gym.spaces.Discrete、gym.spaces.Box、gym.spaces.MultiDiscrete 。
===============================================
baselines算法库baselines/common/input.py模块分析的更多相关文章
- 图像滤镜艺术---ZPhotoEngine超级算法库
原文:图像滤镜艺术---ZPhotoEngine超级算法库 一直以来,都有个想法,想要做一个属于自己的图像算法库,这个想法,在经过了几个月的努力之后,终于诞生了,这就是ZPhotoEngine算法库. ...
- snowland-smx密码算法库
snowland-smx密码算法库 一.snowland-smx密码算法库的介绍 snowland-smx是python实现的国密套件,对标python实现的gmssl,包含国密SM2,SM3,SM4 ...
- scikit-learn 支持向量机算法库使用小结
之前通过一个系列对支持向量机(以下简称SVM)算法的原理做了一个总结,本文从实践的角度对scikit-learn SVM算法库的使用做一个小结.scikit-learn SVM算法库封装了libsvm ...
- 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块
详细解读Python的web.py框架下的application.py模块 这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...
- 第三百零六节,Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置
Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...
- 使用织梦开源的分词算法库编写的YII获取分词扩展
在编辑文章中,很多时候都需要自动根据文章内容获取关键字的功能,因此,本文主要是说明如何在yii中使用织梦开源的分词算法编写一个独立的扩展,可以在不同的模块中使用,步骤如下: 1 到这里下载其他朋友整理 ...
- 四 Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置
Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...
- mahout算法库(四)
mahout算法库 分为三大块 1.聚类算法 2.协同过滤算法(一般用于推荐) 协同过滤算法也可以称为推荐算法!!! 3.分类算法 算法类 算法名 中文名 分类算法 Log ...
- 操作MySQL-数据库的安装及Pycharm模块的导入
操作MySQL-数据库的安装及Pycharm模块的导入 1.基于pyCharm开发环境,在CMD控制台输入依次输入以下步骤: (1)pip3 install PyMySQL < 安装 PyMy ...
- Magicodes.Pay,打造开箱即用的统一支付库,已提供ABP模块封装
Magicodes.Pay,打造开箱即用的统一支付库,已提供ABP模块封装 简介 Magicodes.Pay,是心莱科技团队提供的统一支付库,相关库均使用.NET标准库编写,支持.NET Framew ...
随机推荐
- npm 发布自己的组件库
npm 发布组件库步骤 第一步:注册 npm 账号 第二步:编写自己的组件库 第三部:编写 package.json 可以通过命令生成 npm init { "name": &qu ...
- C# DateTime日期字段转中文文字
public static String ToChineseYearAndMonth(this DateTime dt) { string[] chineseNumbers = { "零&q ...
- 深入了解身份认证和授权机制,看看API请求到底发生了什么?
前段时间写了一篇基于.NetCore环境使用IdentityServer4为API接口鉴权的文章,更多的是从快速上手的角度描述了IdentityServer4的使用.后续使用过程中,自己有了一些其他想 ...
- MySQL常见的后端面试题,你会几道?
为什么分库分表 单表数据量过大,会出现慢查询,所以需要水平分表 可以把低频.高频的字段分开为多个表,低频的表作为附加表,且逻辑更加清晰,性能更优 随着系统的业务模块的增多,放到单库会增加其复杂度,逻辑 ...
- 记录用C#写折半查找算法实现
折半查找算法 前言 最近要考试了,重新回顾一下之前学的算法,今天是折半查找,它的平均比较次数是Log2 n 思想 给定一个有序数组A[0..n-1],和查找值K,返回K在A中的下标. 折半查找需要指定 ...
- 【规范】Git分支管理,看看我司是咋整的
前言 缘由 Git分支管理好,走到哪里都是宝 事情起因: 最近翻看博客中小伙伴评论时,发现文章[规范]看看人家Git提交描述,那叫一个规矩一条回复: 本狗亲测在我司中使用规范的好处,遂把我司的Git分 ...
- 2个qubit的量子门
量子计算机就是基于单qubit门和双qubit门的,再多的量子操作都是基于这两种门.双qubit门比单qubit门难理解得多,不过也重要得多.它可以用来创建纠缠,没有纠缠,量子机就不可能有量子霸权. ...
- C#事件总结
前言:C#的事件也是一项非常关键的技术,必须要深刻的理解,本质上是基于委托的: 事件模型的五个组成部分: 1.事件的拥有者-- event source,对象: 2.事件的成员--event,成员: ...
- 全网最适合入门的面向对象编程教程:08 类和对象的Python实现-@property装饰器:把方法包装成属性
全网最适合入门的面向对象编程教程:08 类和对象的 Python 实现-@property 装饰器:把方法包装成属性 摘要: 本文主要对@property 装饰器的基本定义.使用场景和使用方法进行了介 ...
- Mysql中where条件自动类型转换的坑
我有张表,其主键id字段为varchar(5),内容是5位随机不重复字符串表的内容大概是这样的 id name s8bk2 admin 9f0ps username 在一个方法中我查询了这张表,代码大 ...