baselines算法库baselines/common/input.py模块分析
baselines算法库baselines/common/input.py模块代码:
import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box, MultiDiscrete def observation_placeholder(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space Parameters:
---------- ob_space: gym.Space observation space batch_size: int size of the batch to be fed into input. Can be left None in most cases. name: str name of the placeholder Returns:
------- tensorflow placeholder tensor
''' assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
'Can only deal with Discrete and Box observation spaces for now' dtype = ob_space.dtype
if dtype == np.int8:
dtype = np.uint8 return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) def observation_input(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space, and add input
encoder of the appropriate type.
''' placeholder = observation_placeholder(ob_space, batch_size, name)
return placeholder, encode_observation(ob_space, placeholder) def encode_observation(ob_space, placeholder):
'''
Encode input in the way that is appropriate to the observation space Parameters:
---------- ob_space: gym.Space observation space placeholder: tf.placeholder observation input placeholder
'''
if isinstance(ob_space, Discrete):
return tf.to_float(tf.one_hot(placeholder, ob_space.n))
elif isinstance(ob_space, Box):
return tf.to_float(placeholder)
elif isinstance(ob_space, MultiDiscrete):
placeholder = tf.cast(placeholder, tf.int32)
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
else:
raise NotImplementedError
可以看到input.py模块中一共有三个函数,其中只有一个函数对外提供服务,也就是 observation_input 。
可以看到observation_placeholder函数和encode_observation函数都已经被observation_input函数包装到了一起。
在observation_placeholder函数中根据传入的 env.observation_space变量即可生成对应shape的tf.placeholder变量:
return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
也就是说observation_input函数中的placeholder已经是tf.placeholder类型的了。
encoder_observation 函数根据gym.spaces.observation_space的类型对tf.placeholder进行reshape操作,而tf.placeholder已经是TensorFlow的tensor变量,因此这里对tf.placeholder的操作都是在图中的构建操作,属于TensorFlow的操作。
根据encoder_observation中的代码:
我们可以知道对placeholder的reshape操作主要是对gym.space.observation_space属于Discrete和MultiDiscrete类型进行的。
如果传入的gym.space.observation_space为Discrete类型则对其对应的placeholder进行 tf.one_hot 操作,即:
tf.one_hot(placeholder, ob_space.n)
如果传入的gym.space.observation_space为MultiDiscrete类型则对其对应的placeholder中的每个Discrete进行 tf.one_hot 操作然后在concat拼接,即:
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
其中,如果gym.space.observation_space为Discrete,则其observation空间的大小为env.observation_space.n 。
其中,如果gym.space.observation_space为MultiDiscrete,则其包含的第i个Discrete对应的observation空间的大小为env.observation_space.nvec[i] 。
=======================================
observation_input函数返回的两个变量,其中placeholder是为了给网络feed数据的,而对placeholder进行reshape后的变量是为了方便后面构建神经网络的。
可以知道在gym中gym.space.MultiDiscrete主要是任天堂的游戏使用的,Nintendo Game Controller 。
import gym env_space=gym.spaces.MultiDiscrete([ 5, 2, 2 ]) for i in range(env_space.shape[-1]):
print(env_space.nvec[i])
可以看到MultiDiscrete中的每个Discrete的空间大小使用nvec对应的索引查询。
一般,gym的常见observation空间类型有:
gym.space.Box
gym.spaces.Discrete
gym.spaces.MultiDiscrete
gym.spaces.MultiBinary
gym.spaces.Tuple
gym.spaces.Dict
其中:
gym.spaces.Box
gym.spaces.Discrete
gym.spaces.MultiDiscrete
gym.spaces.MultiBinary
对应的observation均为np.array类型,在这里:
gym.spaces.Discrete
gym.spaces.MultiDiscrete
由于状态空间可以进行one_hot处理以便于后面的计算,因此在common/input.py模块中的obervation_input函数对这样的情况进行处理,对原始的placeholder进行one_hot处理。
由于gym.spaces.MultiBinary没法进行one_hot操作,而observation_space属于类型:gym.spaces.Tuple和gym.spaces.Dict本身已经在baselines库中的其他模块被处理,由于在baselines库中已经对gym的env的observation进行了包装处理,所以可以保证在env.step和env.reset后获得的observation一定是np.arrray类型的,也就是说gym.spaces.Dict和gym.spaces.Tuple类型已经被处理了,feed给神经网络的observation只能是:gym.spaces.Box、gym.spaces.Discrete、gym.spaces.MultiDiscrete、gym.spaces.MultiBinary类型。
查看了一下baselines库的源码对env.observation的处理的代码,该部分代码在common/vec_env/util.py代码中,在该代码中如果observation_space如果是 gym.spaces.Dict和gym.spaces.Tuple则均转为gym.spaces.Dict类型,其中gym.spaces.Tuple转为gym.spaces.Dict时key用从0开始的数字代替,也就是说baselines库中没有对gym.spaces.Dict和gym.spaces.Tuple类型的observation进行过多的处理,也就是说如果原生的env环境为gym.spaces.Dict和gym.spaces.Tuple则传给算法模块的env类型也只能是gym.spaces.Dict类型,这时如果对这样的observation生成的gym.spaces.Dict类型的observation进行placeholder操作则会报错:
或者说在deepq算法中baselines库只能接收的observation类型只可以为gym.spaces.Discrete、gym.spaces.Box、gym.spaces.MultiDiscrete 。
===============================================
baselines算法库baselines/common/input.py模块分析的更多相关文章
- 图像滤镜艺术---ZPhotoEngine超级算法库
原文:图像滤镜艺术---ZPhotoEngine超级算法库 一直以来,都有个想法,想要做一个属于自己的图像算法库,这个想法,在经过了几个月的努力之后,终于诞生了,这就是ZPhotoEngine算法库. ...
- snowland-smx密码算法库
snowland-smx密码算法库 一.snowland-smx密码算法库的介绍 snowland-smx是python实现的国密套件,对标python实现的gmssl,包含国密SM2,SM3,SM4 ...
- scikit-learn 支持向量机算法库使用小结
之前通过一个系列对支持向量机(以下简称SVM)算法的原理做了一个总结,本文从实践的角度对scikit-learn SVM算法库的使用做一个小结.scikit-learn SVM算法库封装了libsvm ...
- 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块
详细解读Python的web.py框架下的application.py模块 这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...
- 第三百零六节,Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置
Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...
- 使用织梦开源的分词算法库编写的YII获取分词扩展
在编辑文章中,很多时候都需要自动根据文章内容获取关键字的功能,因此,本文主要是说明如何在yii中使用织梦开源的分词算法编写一个独立的扩展,可以在不同的模块中使用,步骤如下: 1 到这里下载其他朋友整理 ...
- 四 Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置
Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...
- mahout算法库(四)
mahout算法库 分为三大块 1.聚类算法 2.协同过滤算法(一般用于推荐) 协同过滤算法也可以称为推荐算法!!! 3.分类算法 算法类 算法名 中文名 分类算法 Log ...
- 操作MySQL-数据库的安装及Pycharm模块的导入
操作MySQL-数据库的安装及Pycharm模块的导入 1.基于pyCharm开发环境,在CMD控制台输入依次输入以下步骤: (1)pip3 install PyMySQL < 安装 PyMy ...
- Magicodes.Pay,打造开箱即用的统一支付库,已提供ABP模块封装
Magicodes.Pay,打造开箱即用的统一支付库,已提供ABP模块封装 简介 Magicodes.Pay,是心莱科技团队提供的统一支付库,相关库均使用.NET标准库编写,支持.NET Framew ...
随机推荐
- js沙雕排序之睡眠排序&随机排序
1.睡眠排序,只要睡的时间多少就可以排序出来不要在乎时间多少 var arr=[4,77,741,41,142,52,244]; var sleepSort=function(arr,callback ...
- Java动态获取实现类 Class.forName(clazz).newInstance()和applicationContext.getBean, bean Map寻找方式,Java Map定义和初始化方法
Java动态获取实现类 Class.forName(clazz).newInstance()和applicationContext.getBean, bean Map寻找方式,Java Map定义和初 ...
- Timing!!!
End or Beginning "毕业",一个令人无限憧憬的具象化名词.适逢高考结束,又有一批人将奔赴更远的地方,离开他们生活了十八年的城市,在这之中亦然有着曾经的我们.但大家把 ...
- Python中multiprocessing.Pool进程池实现守护进程的方法
前言 在multiprocessing.Process中可以使用p.daemon=True将子进程p设置为守护进程. 那么在multiprocessing.Pool进程池中怎么实现这个功能呢? 什么是 ...
- 【动手学深度学习】第三章笔记:线性回归、SoftMax 回归、交叉熵损失
这章感觉没什么需要特别记住的东西,感觉忘了回来翻一翻代码就好. 3.1 线性回归 3.1.1 线性回归的基本元素 1. 线性模型 \(\boldsymbol{x}^{(i)}\) 是一个列向量,表示第 ...
- MySQL自定义函数(User Define Function)开发实例——发送TCP/UDP消息
开发背景 当数据库中某个字段的值改为特定值时,实时发送消息通知到其他系统. 实现思路 监控数据库中特定字段值的变化可以用数据库触发器实现.还需要实现一个自定义的函数,接收一个字符串参数,然后将这个字符 ...
- python_8 拆包、内置函数和高阶函数
一.查缺补漏 1. \t 子表符,用于对其二.拆包 1. 拆包:顾名思义就是将可迭代的对象如元组,列表,字符串,集合,字典,拆分出相对应的元素 2. 形式:拆包一般分两种方式,一种是以变量的方式来接收 ...
- python基础-字符串str " "
字符串的定义和操作 字符串的特性: 元素数量 支持多个 元素类型 仅字符 下标索引 支持 重复元素 支持 可修改性 不支持 数据有序 是 使用场景 一串字符的记录场景 字符串的相关操作: my_str ...
- 国产化率100%!全志科技A40i工业核心板规格书资料分享
1.核心板简介 创龙科技SOM-TLA40i是一款基于全志科技A40i处理器设计的4核ARM Cortex-A7国产工业核心板,每核主频高达1.2GHz. 核心板通过邮票孔连接方式引出CSI.TVIN ...
- 英特尔开源新等宽字体Intel One Mono,称可保护开发者视力
英特尔开源了一款面向开发者的新等宽字体 "Intel One Mono ",这是一种富有表现力的等宽字体系列,集清晰度.易读性和开发者视力保护于一体. Intel One Mono ...