baselines库中cmd_util.py模块对atari游戏的包装为什么要分成两部分并在中间加入flatten操作呢?
如题:
cmd_util.py模块中对应的代码:
可以看到不论是atari游戏还是retro游戏,在进行游戏环境包装的时候都是分成两部分的,如atari游戏,第一部分是make_atari,第二部分是wrap_deepmind,在两者之间有一个FlattenObservation操作。
通过FlattenObservation的代码可以知道,该操作是将observation的space从dict变为np.array,也就是gym.spaces.Dict变为gym.spaces.Box类型:
import numpy as np
import gym.spaces as spaces
from gym import ObservationWrapper class FlattenObservation(ObservationWrapper):
r"""Observation wrapper that flattens the observation."""
def __init__(self, env):
super(FlattenObservation, self).__init__(env) flatdim = spaces.flatdim(env.observation_space)
self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(flatdim,), dtype=np.float32) def observation(self, observation):
return spaces.flatten(self.env.observation_space, observation)
对atari游戏的两个包装方法来看:
def make_atari(env_id, max_episode_steps=None):
env = gym.make(env_id)
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps=max_episode_steps)
return env def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
"""Configure environment for DeepMind-style Atari.
"""
if episode_life:
env = EpisodicLifeEnv(env)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if scale:
env = ScaledFloatFrame(env)
if clip_rewards:
env = ClipRewardEnv(env)
if frame_stack:
env = FrameStack(env, 4)
return env
make_atari部分并不对observation部分进行处理,而wrap_deepmind部分才对observation部分进行处理,因此在baselines库中对这两部分拆开并在中间进行FlattenObservation操作 ,这样以好保证在wrap_deepmind部分的操作可以直接对np.array类型的observation进行操作。
个人评价:
其实感觉这个FlattenObservation操作还是有一定欠缺的,就是对MultiDiscrete的observation,没有对observation进行one-hot操作。
而这个代码中对Discrete的observation是进行了one-hot编码,而对MultiDiscrete的observation并没有进行one-hot编码,而这个对应MultiDiscrete是否应该进行one-hot编码也是要看具体情况的,如果observation的spaces虽然属于MultiDiscrete但是它的spaces.shape的很大,也就是observation的空间维度很大,这样的话也没有必要进行one-hot编码,但是如果shape比较小,如为2,这样的,那么就有必要one-hot。
如:
import gym obs_space=gym.spaces.MultiDiscrete((3,5)) print(obs_space.shape)
print(obs_space.nvec)
可以知道如果observation的space属于上面的情况,那么不one-hot编码observation的空间编码长度为2, 如果one-hot编码后长度为8。
也就是不one-hot编码的一个observation,如:(2,3) ,one-hot编码后为(010 00100),
从这个形式上来看,好像对于MultiDsicrete的observation是否进行one-hot编码好像也没有太大的影响,或许baselines中的设置还是说的过去的。
但是这个代码中还有一个地方需要注意:
self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(flatdim,), dtype=np.float32)
从这个代码中可以看到不论observation的原始数据类型是什么,只要进行了flatten操作都会把数据类型转为np.float32,这样的操作可能导致精度损失,有可能造成空间存储变大,所以这个FlattenObservation操作是非必要不使用的,不然很可能出问题的。
或许这也是在run.py中对使用FlattenObservation操作的限制了:
可以看到在baselines中只有对observation_space属于gym.spaces.Dict的才进行FlattenObservation操作。
给出一个自己FlattenObservation操作单独写在一个文件中的代码:
import numpy as np
import gym.spaces as spaces
from gym import ObservationWrapper from gym.spaces import Box
from gym.spaces import Discrete
from gym.spaces import MultiDiscrete
from gym.spaces import MultiBinary
from gym.spaces import Tuple
from gym.spaces import Dict def flatdim(space):
if isinstance(space, Box):
return int(np.prod(space.shape))
elif isinstance(space, Discrete):
return int(space.n)
elif isinstance(space, Tuple):
return int(sum([flatdim(s) for s in space.spaces]))
elif isinstance(space, Dict):
return int(sum([flatdim(s) for s in space.spaces.values()]))
elif isinstance(space, MultiBinary):
return int(space.n)
elif isinstance(space, MultiDiscrete):
return int(np.prod(space.shape))
else:
raise NotImplementedError def flatten(space, x):
if isinstance(space, Box):
return np.asarray(x, dtype=np.float32).flatten()
elif isinstance(space, Discrete):
onehot = np.zeros(space.n, dtype=np.float32)
onehot[x] = 1.0
return onehot
elif isinstance(space, Tuple):
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
elif isinstance(space, Dict):
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
elif isinstance(space, MultiBinary):
return np.asarray(x).flatten()
elif isinstance(space, MultiDiscrete):
return np.asarray(x).flatten()
else:
raise NotImplementedError class FlattenObs(ObservationWrapper):
r"""Observation wrapper that flattens the observation.""" def __init__(self, env):
super(FlattenObs, self).__init__(env) _flatdim = flatdim(env.observation_space)
self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(_flatdim,), dtype=np.float32) def observation(self, observation):
return flatten(self.env.observation_space, observation) if __name__ == '__main__':
import gym
FlattenObs(gym.make('Pong-v0'))
print(gym.make('Pong-v0').observation_space)
print(gym.make('Pong-v0').observation_space.dtype)
==========================================
修正一点:
在make_atari函数中的MaxAndSkipEnv,也是对observation进行操作的,换句话说,在baselines中对FlattenObservation操作的设置本身也是有一定没有表述清的,那就是atari游戏的observation,也包括retro都是图像,也就是np.array类型,本身也不需要FlattenObservation操作。
也就是说,下面的代码块应该是在一起的:
上面代码可以改为:
else:
env = gym.make(env_id, **env_kwargs) if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
env = FlattenObservation(env)
=============================================
baselines库中cmd_util.py模块对atari游戏的包装为什么要分成两部分并在中间加入flatten操作呢?的更多相关文章
- Python 3 中的json模块使用
1. 概述 JSON (JavaScript Object Notation)是一种使用广泛的轻量数据格式. Python标准库中的json模块提供了JSON数据的处理功能. Python中一种非常常 ...
- Python中的Pexpect模块的简单使用
Pexpect 是一个用来启动子程序并对其进行自动控制的 Python 模块. Pexpect 可以用来和像 ssh.ftp.passwd.telnet 等命令行程序进行自动交互.以下所有代码都是在K ...
- 在Team Foundation Server (TFS)的代码库或配置库中查找文件或代码
[update 2017.2.11] 最新版本的TFS 2017已经增加了代码搜索功能,可以参考这个链接 https://blogs.msdn.microsoft.com/visualstudioal ...
- (转)python标准库中socket模块详解
python标准库中socket模块详解 socket模块简介 原文:http://www.lybbn.cn/data/datas.php?yw=71 网络上的两个程序通过一个双向的通信连接实现数据的 ...
- 在Pycharm中导入第三方模块库(诸如:matplotlib、numpy等)
在Pycharm中导入第三方模块库 一.打开pycharm: 二.点击菜单上的“file” -> “setting”: 三.步骤二完成后出现界面如下所示.选中你的项目(比如thisyan Pro ...
- python3 中引用 HTMLTestRunner.py 模块的注意事项
HTMLTestRunner.py支持python2中运行,如果在python3.6.2中引用HTMLTestRunner.py模块,需要做一下更改: 1.更改HTMLTestRunner.py模块中 ...
- python模块中__init__.py的作用
基本概念先上结论举例解释实验一:不包含__init__.py实验二:A中包含__init__.py实验三:A.A_A中也包含__init__.py进阶基本概念概念 解释import 即导入,方式就是在 ...
- [转载]python中的sys模块(二)
#!/usr/bin/python # Filename: using_sys.py import sys print 'The command line arguments are:' for i ...
- python中常用的模块的总结
1. 模块和包 a.定义: 模块用来从逻辑上组织python代码(变量,函数,类,逻辑:实现一个功能),本质就是.py结尾的python文件.(例如:文件名:test.py,对应的模块名:test) ...
- Python中的logging模块
http://python.jobbole.com/86887/ 最近修改了项目里的logging相关功能,用到了python标准库里的logging模块,在此做一些记录.主要是从官方文档和stack ...
随机推荐
- monaco-editor 的 Language Services
我们是袋鼠云数栈 UED 团队,致力于打造优秀的一站式数据中台产品.我们始终保持工匠精神,探索前端道路,为社区积累并传播经验价值. 本文作者:修能 这是一段平平无奇的 SQL 语法 SELECT id ...
- 在线Bcrypt加密、验证工具
在线bcrypt加密,bcrypt算法是一种密码哈希算法,它是基于Blowfish加密算法改进的,能够生成安全性很高的哈希值,并且可以通过调整计算时间来提高安全性.本工具支持在线Bcrypt加密及验证 ...
- uniapp windows 上架 apple store
香蕉云 蒲公英 ios上架助手iOS Development 开发!先用上架助手在certificates里面生成一个p12文件在profiles里面生成mobileprovision文件就欧克了 需 ...
- DHorse v1.5.1 发布,基于 k8s 的发布平台
版本说明 新增特性 支持k8s的v1.30.x版本: 优化特性 优化回滚功能: 修复注册来源的回滚问题: 新增和修改应用时校验应用名: 升级kubernetes-client至v6.13.0: 调整部 ...
- 「C++」复杂模拟【壹】
建议开启目录食用 阅读本文之前建议您先看这里,如果您已经看完了,那么就可以放心大胆的学习本文了. 我认为其实本文的难度还是比较大的,今天我们题是来自山东省省选,所以建议大家谨慎阅读,如果您是专业程序员 ...
- 海思SDK 学习 :002-实例代码分析
背景 需要了解 海思HI35xx平台软件开发快速入门之背景知识,为了方便测试,还需要了解 海思SDK 的安装 知识 由于海思的应用程序启动 MPP 业务前,必须完成 MPP 系统初始化工作.同理,应用 ...
- (要做的事情)利用MNIST识别自己创建的手写数据
(要做的事情)利用MNIST识别自己创建的手写数据 看懂MNIST 进阶教程,了解CNN
- C#语言编写的仅有8KB大小的简易贪吃蛇开源游戏
前言 今天大姚给大家分享一款由C#语言编写的仅有8KB大小的简易贪吃蛇开源游戏:SeeSharpSnake. 项目特点 该仓库中的项目文件和脚本可以用多种不同的配置构建相同的游戏,每个配置生成的输出大 ...
- scarpy基础
1. 创建项目 scrapy startproject 项目名称 2. 进入项目 cd 项目名称 3. 创建爬虫 scrapy genspider 名字 域名 4. 可能需要start_urls,修改 ...
- 函数式编程(Lambda、Stream流、Optional等)
# 声明 文档来源:Github@shuhongfan 源文档:B站UP主:三更草堂 # 函数式编程-Stream流 # 概述 # 为什么学? 基操,否则看不懂别人写的优雅代码 简化代码,不想看到有些 ...