baselines算法库baselines/common/input.py模块代码:

import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box, MultiDiscrete def observation_placeholder(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space Parameters:
---------- ob_space: gym.Space observation space batch_size: int size of the batch to be fed into input. Can be left None in most cases. name: str name of the placeholder Returns:
------- tensorflow placeholder tensor
''' assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
'Can only deal with Discrete and Box observation spaces for now' dtype = ob_space.dtype
if dtype == np.int8:
dtype = np.uint8 return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) def observation_input(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space, and add input
encoder of the appropriate type.
''' placeholder = observation_placeholder(ob_space, batch_size, name)
return placeholder, encode_observation(ob_space, placeholder) def encode_observation(ob_space, placeholder):
'''
Encode input in the way that is appropriate to the observation space Parameters:
---------- ob_space: gym.Space observation space placeholder: tf.placeholder observation input placeholder
'''
if isinstance(ob_space, Discrete):
return tf.to_float(tf.one_hot(placeholder, ob_space.n))
elif isinstance(ob_space, Box):
return tf.to_float(placeholder)
elif isinstance(ob_space, MultiDiscrete):
placeholder = tf.cast(placeholder, tf.int32)
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
else:
raise NotImplementedError

可以看到input.py模块中一共有三个函数,其中只有一个函数对外提供服务,也就是 observation_input 。

可以看到observation_placeholder函数和encode_observation函数都已经被observation_input函数包装到了一起。

在observation_placeholder函数中根据传入的 env.observation_space变量即可生成对应shape的tf.placeholder变量:

return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)

也就是说observation_input函数中的placeholder已经是tf.placeholder类型的了。

encoder_observation 函数根据gym.spaces.observation_space的类型对tf.placeholder进行reshape操作,而tf.placeholder已经是TensorFlow的tensor变量,因此这里对tf.placeholder的操作都是在图中的构建操作,属于TensorFlow的操作。

根据encoder_observation中的代码:

我们可以知道对placeholder的reshape操作主要是对gym.space.observation_space属于Discrete和MultiDiscrete类型进行的。

如果传入的gym.space.observation_space为Discrete类型则对其对应的placeholder进行  tf.one_hot  操作,即:

tf.one_hot(placeholder, ob_space.n)

如果传入的gym.space.observation_space为MultiDiscrete类型则对其对应的placeholder中的每个Discrete进行  tf.one_hot  操作然后在concat拼接,即:

one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)

其中,如果gym.space.observation_space为Discrete,则其observation空间的大小为env.observation_space.n 。

其中,如果gym.space.observation_space为MultiDiscrete,则其包含的第i个Discrete对应的observation空间的大小为env.observation_space.nvec[i] 。

=======================================

observation_input函数返回的两个变量,其中placeholder是为了给网络feed数据的,而对placeholder进行reshape后的变量是为了方便后面构建神经网络的。

可以知道在gym中gym.space.MultiDiscrete主要是任天堂的游戏使用的,Nintendo Game Controller 。

import gym

env_space=gym.spaces.MultiDiscrete([ 5, 2, 2 ])

for i in range(env_space.shape[-1]):
print(env_space.nvec[i])

可以看到MultiDiscrete中的每个Discrete的空间大小使用nvec对应的索引查询。

一般,gym的常见observation空间类型有:

gym.space.Box

gym.spaces.Discrete

gym.spaces.MultiDiscrete

gym.spaces.MultiBinary

gym.spaces.Tuple

gym.spaces.Dict

其中:

gym.spaces.Box

gym.spaces.Discrete

gym.spaces.MultiDiscrete

gym.spaces.MultiBinary

对应的observation均为np.array类型,在这里:

gym.spaces.Discrete

gym.spaces.MultiDiscrete

由于状态空间可以进行one_hot处理以便于后面的计算,因此在common/input.py模块中的obervation_input函数对这样的情况进行处理,对原始的placeholder进行one_hot处理。

由于gym.spaces.MultiBinary没法进行one_hot操作,而observation_space属于类型:gym.spaces.Tuple和gym.spaces.Dict本身已经在baselines库中的其他模块被处理,由于在baselines库中已经对gym的env的observation进行了包装处理,所以可以保证在env.step和env.reset后获得的observation一定是np.arrray类型的,也就是说gym.spaces.Dict和gym.spaces.Tuple类型已经被处理了,feed给神经网络的observation只能是:gym.spaces.Box、gym.spaces.Discrete、gym.spaces.MultiDiscrete、gym.spaces.MultiBinary类型。

查看了一下baselines库的源码对env.observation的处理的代码,该部分代码在common/vec_env/util.py代码中,在该代码中如果observation_space如果是 gym.spaces.Dict和gym.spaces.Tuple则均转为gym.spaces.Dict类型,其中gym.spaces.Tuple转为gym.spaces.Dict时key用从0开始的数字代替,也就是说baselines库中没有对gym.spaces.Dict和gym.spaces.Tuple类型的observation进行过多的处理,也就是说如果原生的env环境为gym.spaces.Dict和gym.spaces.Tuple则传给算法模块的env类型也只能是gym.spaces.Dict类型,这时如果对这样的observation生成的gym.spaces.Dict类型的observation进行placeholder操作则会报错:

或者说在deepq算法中baselines库只能接收的observation类型只可以为gym.spaces.Discrete、gym.spaces.Box、gym.spaces.MultiDiscrete 。

===============================================

baselines算法库baselines/common/input.py模块分析的更多相关文章

  1. 图像滤镜艺术---ZPhotoEngine超级算法库

    原文:图像滤镜艺术---ZPhotoEngine超级算法库 一直以来,都有个想法,想要做一个属于自己的图像算法库,这个想法,在经过了几个月的努力之后,终于诞生了,这就是ZPhotoEngine算法库. ...

  2. snowland-smx密码算法库

    snowland-smx密码算法库 一.snowland-smx密码算法库的介绍 snowland-smx是python实现的国密套件,对标python实现的gmssl,包含国密SM2,SM3,SM4 ...

  3. scikit-learn 支持向量机算法库使用小结

    之前通过一个系列对支持向量机(以下简称SVM)算法的原理做了一个总结,本文从实践的角度对scikit-learn SVM算法库的使用做一个小结.scikit-learn SVM算法库封装了libsvm ...

  4. 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块

    详细解读Python的web.py框架下的application.py模块   这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...

  5. 第三百零六节,Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置

    Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...

  6. 使用织梦开源的分词算法库编写的YII获取分词扩展

    在编辑文章中,很多时候都需要自动根据文章内容获取关键字的功能,因此,本文主要是说明如何在yii中使用织梦开源的分词算法编写一个独立的扩展,可以在不同的模块中使用,步骤如下: 1 到这里下载其他朋友整理 ...

  7. 四 Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置

    Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...

  8. mahout算法库(四)

    mahout算法库 分为三大块 1.聚类算法 2.协同过滤算法(一般用于推荐) 协同过滤算法也可以称为推荐算法!!! 3.分类算法 算法类 算法名 中文名 分类算法               Log ...

  9. 操作MySQL-数据库的安装及Pycharm模块的导入

    操作MySQL-数据库的安装及Pycharm模块的导入 1.基于pyCharm开发环境,在CMD控制台输入依次输入以下步骤: (1)pip3 install PyMySQL  < 安装 PyMy ...

  10. Magicodes.Pay,打造开箱即用的统一支付库,已提供ABP模块封装

    Magicodes.Pay,打造开箱即用的统一支付库,已提供ABP模块封装 简介 Magicodes.Pay,是心莱科技团队提供的统一支付库,相关库均使用.NET标准库编写,支持.NET Framew ...

随机推荐

  1. 流程控制之case

    1.case语句作用 case和if一样,都是用于处理多分支的条件判断 但是在条件较多的情况,if嵌套太多就不够简洁了 case语句就更简洁和规范了 2.case用法参考 常见用法就是如根据用户输入的 ...

  2. Vue学习:20.综合案例-商品列表

    学而时用之,方能融会贯通! 实例:商品列表 实现功能 要求表格组件支持动态数据渲染.自定义表头和主体.标签组件需要双击显示输入框并获得焦点,可以编辑标签信息. 思路 首先是表格组件,动态渲染需要使用组 ...

  3. 使用spark-sql处理Doris大表关联

    背景 最近项目上有一个需求,需要将两张表(A表和B表)的数据进行关联并回写入其中一张表(A表),两张表都是分区表,但是关联条件不包括分区字段. 分析过程 方案一 最朴素的想法,直接关联执行,全表关联, ...

  4. 关于 Jupyter Nbconvert 自定义 LaTeX 模板,中文兼容与格式设置,从 Notebook 构建 LaTeX PDF 文档

    目录 为什么会有这篇随笔的内容? 简述一下我遇到的问题 Nbconvert 转换 .ipynb 文件的基本方法 Jupyter Nbconvert 构建中文 \(\LaTeX\) 文档的痛点 Jupy ...

  5. IS-IS总结

    IS-IS     管理距离115     ISIS是链路状态协议     封装在数据链路层,所以没有协议号     使用SPF算法计算最短路径     没有骨干区的概念     使用IIH(ISIS ...

  6. 一招解决github访问慢的问题

    ​ 之前我在网上搜过解决办法,其中一个是修改 hosts 文件,但是效果不太理想.我在这里给大家推荐github上的一个开源项目:FastGithub .用了这个之后,效果就比较理想了,次次都能访问到 ...

  7. PyTorch程序练习(一):PyTorch实现CIFAR-10多分类

    一.准备数据 代码 import torchvision import torchvision.transforms as transforms from torch.utils.data impor ...

  8. OPC 数据采集 解决方案

    笔者计划从此篇博客开始,详细介绍OPC数据采集采集过程.包括常用组态软件介绍,数据接入,OPC接入过程,常用OPC数据接入与处理全流程范例,分享相关案例Demo. 因为分享的都是个人实际工作经验中的 ...

  9. java开发webservice报Service(URL, QName, WebServiceFeature[]) is undefined错误的解决方法

    Description Resource Path Location TypeThe constructor Service(URL, QName, WebServiceFeature[]) is u ...

  10. ubuntu16.04 安装 eclips c/c++

    前言 最近需要在ubuntu16上使用eclips编译c,尝试了apt安装和官网最新包安装甚至应用商店安装,效果都不太理想,现在把我的安装方法记录一下. 正文 !!!前提,已经自己配置好了java8的 ...