common/atari_wrappers.py模块代码如下;

import numpy as np
import os
os.environ.setdefault('PATH', '')
from collections import deque
import gym
from gym import spaces
import cv2
cv2.ocl.setUseOpenCL(False)
from .wrappers import TimeLimit class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
gym.Wrapper.__init__(self, env)
self.noop_max = noop_max
self.override_num_noops = None
self.noop_action = 0
assert env.unwrapped.get_action_meanings()[0] == 'NOOP' def reset(self, **kwargs):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset(**kwargs)
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(self.noop_action)
if done:
obs = self.env.reset(**kwargs)
return obs def step(self, ac):
return self.env.step(ac) class FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3 def reset(self, **kwargs):
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset(**kwargs)
return obs def step(self, ac):
return self.env.step(ac) class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True def step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so it's important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info def reset(self, **kwargs):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset(**kwargs)
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.lives = self.env.unwrapped.ale.lives()
return obs class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env, skip=4):
"""Return only every `skip`-th frame"""
gym.Wrapper.__init__(self, env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
self._skip = skip def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for i in range(self._skip):
obs, reward, done, info = self.env.step(action)
if i == self._skip - 2: self._obs_buffer[0] = obs
if i == self._skip - 1: self._obs_buffer[1] = obs
total_reward += reward
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0) return max_frame, total_reward, done, info def reset(self, **kwargs):
return self.env.reset(**kwargs) class ClipRewardEnv(gym.RewardWrapper):
def __init__(self, env):
gym.RewardWrapper.__init__(self, env) def reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return np.sign(reward) class WarpFrame(gym.ObservationWrapper):
def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
"""
Warp frames to 84x84 as done in the Nature paper and later work. If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
observation should be warped.
"""
super().__init__(env)
self._width = width
self._height = height
self._grayscale = grayscale
self._key = dict_space_key
if self._grayscale:
num_colors = 1
else:
num_colors = 3 new_space = gym.spaces.Box(
low=0,
high=255,
shape=(self._height, self._width, num_colors),
dtype=np.uint8,
)
if self._key is None:
original_space = self.observation_space
self.observation_space = new_space
else:
original_space = self.observation_space.spaces[self._key]
self.observation_space.spaces[self._key] = new_space
assert original_space.dtype == np.uint8 and len(original_space.shape) == 3 def observation(self, obs):
if self._key is None:
frame = obs
else:
frame = obs[self._key] if self._grayscale:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(
frame, (self._width, self._height), interpolation=cv2.INTER_AREA
)
if self._grayscale:
frame = np.expand_dims(frame, -1) if self._key is None:
obs = frame
else:
obs = obs.copy()
obs[self._key] = frame
return obs class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Stack k last frames. Returns lazy array, which is much more memory efficient. See Also
--------
baselines.common.atari_wrappers.LazyFrames
"""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype) def reset(self):
ob = self.env.reset()
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob() def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames)) class ScaledFloatFrame(gym.ObservationWrapper):
def __init__(self, env):
gym.ObservationWrapper.__init__(self, env)
self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32) def observation(self, observation):
# careful! This undoes the memory optimization, use
# with smaller replay buffers only.
return np.array(observation).astype(np.float32) / 255.0 class LazyFrames(object):
def __init__(self, frames):
"""This object ensures that common frames between the observations are only stored once.
It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
buffers. This object should only be converted to numpy array before being passed to the model. You'd not believe how complex the previous solution was."""
self._frames = frames
self._out = None def _force(self):
if self._out is None:
self._out = np.concatenate(self._frames, axis=-1)
self._frames = None
return self._out def __array__(self, dtype=None):
out = self._force()
if dtype is not None:
out = out.astype(dtype)
return out def __len__(self):
return len(self._force()) def __getitem__(self, i):
return self._force()[i] def count(self):
frames = self._force()
return frames.shape[frames.ndim - 1] def frame(self, i):
return self._force()[..., i] def make_atari(env_id, max_episode_steps=None):
env = gym.make(env_id)
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps=max_episode_steps)
return env def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
"""Configure environment for DeepMind-style Atari.
"""
if episode_life:
env = EpisodicLifeEnv(env)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if scale:
env = ScaledFloatFrame(env)
if clip_rewards:
env = ClipRewardEnv(env)
if frame_stack:
env = FrameStack(env, 4)
return env

该模块功能为对gym的env进行包装,也就是修改reset函数,step函数,同时也可以直接修改observation和reward的值,其实现原理为继承gym环境类的gym.Wrapper gym.RewardWrappergym.ObservationWrapper

其中,覆盖gym.Wrapper中的reset函数和step函数可以实现对env的操作包装,对gym.ObservationWrapper中的observation函数进行覆盖可以实现对环境返回的状态进行修改,对gym.RewardWrapper中的reward函数进行覆盖则可以实现对环境返回的奖励值进行修改的目的。

NoopResetEnv类只对reset函数进行覆盖,操作为在reset的时候随机进行一定数量的noop操作,目的是使游戏开始时的初始状态具有一定的变动。

FireResetEnv类只对reset函数进行覆盖,在reset操作的时候加入动作1和动作2,因为有的游戏需要有动作1作为开始键。

需要注意的是上面两个类在reset操作的时候不会出现游戏终止的情况,即dong=True,但是也写了done=True情况的处理操作,不过该部分可以忽略。

至于FireResetEnv为什么在reset的时候加入动作2,也就是env.step(2)还是没有找到解释的地方。

EpisodicLifeEnv类对step函数和reset函数都进行了覆盖,也就是考虑了游戏中有多条命的情况,如果是总游戏没有结束但是一个游戏live结束也对其进行reset操作,不过如果是总的游戏结束则调用包装的内层类reset,如果是lives减一但是不等于0的情况则进行0动作的操作,也就是fire操作,即 env.step(0)。在step函数中判断是总游戏回合结束还是丢失了一个游戏生命。

MaxAndSkipEnv类,进行重复动作并对得到的最近的两个observation进行取最大值,如默认的skip输入值为4,则进行4次的动作重复,也就是MaxAndSkipEnv类对象接受到一个动作的step则会调用内层包装的env进行4次相同action的step。

由于调用内层的env进行skip数量的step会得到4个observation,对最新得到的两个observation进行取max操作后作为最终的observation返回。

对于这个里面的step操作有一些个人的看法,感觉这个官方给出的设计有些瑕疵,自己的修改如下:

    def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for i in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer[i%2] = obs
total_reward += reward
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0) return max_frame, total_reward, done, info

个人的修改主要是考虑了当done=True时还没有对_obs_buffer进行填充的情况,其实这个修改对性能是没有影响的,这样改并不会提升性能,这个只不过可能有些强迫症类型的修改了。同时我们需要注意在step操作中调用内层env进行的skip次的step所得的奖励进行的加和。

类ClipRewardEnv,对step返回的reward进行了处理,也就是说如果有其他的Env类对ClipRewardEnv进行包装,那么ClipRewardEnv向上返回的reward只会为-1,0, +1 。

类WarpFrame(gym.ObservationWrapper)对所有内层的Env的observation进行处理,进行rgb转为灰度图后进行resize操作。需要注意的是rgb图转为灰度图后为保持ndim维度不变会在-1维度上进行扩充。

该类中需要注意的一个地方:

If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
observation should be warped.

dict_space_key可能是给某种环境类型使用的,该种类型的环境返回的observation不是np.array类型而是dict类型,如果需要获得observation中的np.array就需要调用dict_space_key,即observation[dict_space_key]。

把rgb图片转为gray灰度图后维度变少了,需要在-1维度上进行维度扩展。

        if self._grayscale:
frame = np.expand_dims(frame, -1)

类FrameStack(gym.Wrapper) , 将多个图片在-1维度上进行拼接,在reset函数时将内部传入进行包装的env的reset后返回的observation进行k次copy,其中k个图片的存储使用队列形式  deque([], maxlen=k) 。

    def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info

每个step得到的observation都用队列存储,最后向上返回的observation则是使用self._get_ob()生成的。

    def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames))

可以看到每个observation都是k个图片构成的队列生成的LazyFrames对象。

类class LazyFrames(object), 将需要拼接的图片队列转为np.array类型。

类ScaledFloatFrame(gym.ObservationWrapper),将np.uint8类型转为np.float32,范围由0~255转为0~1 。

------------------------------------------------

def make_atari(env_id, max_episode_steps=None):
env = gym.make(env_id)
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps=max_episode_steps)
return env def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
"""Configure environment for DeepMind-style Atari.
"""
if episode_life:
env = EpisodicLifeEnv(env)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if scale:
env = ScaledFloatFrame(env)
if clip_rewards:
env = ClipRewardEnv(env)
if frame_stack:
env = FrameStack(env, 4)
return env

上面这两个类非标选择不同的包装器对env进行包装,也就是deepmind类型的atari环境包装和非deepmind的atari环境包装,这两个用法暂时不是很明白。

=============================================

baselines算法库common/atari_wrappers.py模块分析的更多相关文章

  1. openstack 中 log模块分析

    1 . 所在模块,一般在openstack/common/log.py,其实最主要的还是调用了python中的logging模块: 入口函数在 def setup(product_name, vers ...

  2. 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块

    详细解读Python的web.py框架下的application.py模块   这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...

  3. Python标准库笔记(9) — functools模块

    functools 作用于函数的函数 functools 模块提供用于调整或扩展函数和其他可调用对象的工具,而无需完全重写它们. 装饰器 partial 类是 functools 模块提供的主要工具, ...

  4. python标准库介绍——12 time 模块详解

    ==time 模块== ``time`` 模块提供了一些处理日期和一天内时间的函数. 它是建立在 C 运行时库的简单封装. 给定的日期和时间可以被表示为浮点型(从参考时间, 通常是 1970.1.1 ...

  5. mahout算法库(四)

    mahout算法库 分为三大块 1.聚类算法 2.协同过滤算法(一般用于推荐) 协同过滤算法也可以称为推荐算法!!! 3.分类算法 算法类 算法名 中文名 分类算法               Log ...

  6. scikit-learn 支持向量机算法库使用小结

    之前通过一个系列对支持向量机(以下简称SVM)算法的原理做了一个总结,本文从实践的角度对scikit-learn SVM算法库的使用做一个小结.scikit-learn SVM算法库封装了libsvm ...

  7. OpenRisc-43-or1200的IF模块分析

    引言 “喂饱饥饿的CPU”,是计算机体系结构设计者时刻要考虑的问题.要解决这个问题,方法大体可分为两部分,第一就是利用principle of locality而引进的cache技术,缩短取指时间,第 ...

  8. 【转】python模块分析之unittest测试(五)

    [转]python模块分析之unittest测试(五) 系列文章 python模块分析之random(一) python模块分析之hashlib加密(二) python模块分析之typing(三) p ...

  9. 【转】python模块分析之hashlib加密(二)

    [转]python模块分析之hashlib加密(二) hashlib模块是用来对字符串进行hash加密的模块,明文与密文是一一对应不变的关系:用于注册.登录时用户名.密码等加密使用.一.函数分析:1. ...

  10. 【转】python之random模块分析(一)

    [转]python之random模块分析(一) random是python产生伪随机数的模块,随机种子默认为系统时钟.下面分析模块中的方法: 1.random.randint(start,stop): ...

随机推荐

  1. work05

    第一题:分析以下需求,并用代码实现 手机类Phone 属性: 品牌brand 价格price 行为: 打电话call() 发短信sendMessage() 玩游戏playGame() 要求: 1.按照 ...

  2. String和StringBuffer、StringBuilder的区别是什么?String为什么是不可变的

    a.可变性:String类中使用字符数组保存字符串,private  final   char   value[],所以string对象是不可变的.StringBuilder与StringBuffer ...

  3. sftp jsch文件移动备份的思路

    1.jsch jar包不支持mv cp等移动复制的功能,转换思路,sftp下载文件到本地服务器,目录可以考虑使用/年/月/日层级. 2.然后sftp下载操作完毕,记录一张文件操作表,记录下载状态. 3 ...

  4. Flutter 借助SearchDelegate实现搜索页面,实现搜索建议、搜索结果,解决IOS拼音问题

    搜索界面使用Flutter自带的SearchDelegate组件实现,通过魔改实现如下效果: 搜素建议 搜索结果,支持刷新和加载更多 IOS中文输入拼音问题 界面预览 拷贝源码 将SearchDele ...

  5. python 将查询到数据,处理成包含列名和数据的字典类型数据

    try: self.connect_dbserver() self.cursor.execute(sql) res = self.cursor.fetchall() # 返回的是数组的类型 print ...

  6. java 8 stream toMap问题

    最近使用java的stream功能有点多,理由有2: 1)少写了不少代码 2)在性能可以接受的范围内 在巨大的collection基础上使用stream,没有什么经验.而非关键业务上,乐于使用stre ...

  7. 11-Python网络编程

    socket包介绍 Socket又称"套接字",应用程序通常通过"套接字"向网络发出请求或者应答网络请求,使主机间或者一台计算机上的进程间可以通讯. 创建一个T ...

  8. 防止unordered_map 被卡方法

    codeforces 上看到的,mark 一下代码.原作者:neal,原链接:https://codeforces.com/blog/entry/62393 struct custom_hash { ...

  9. 带有ttl的Lru在Rust中的实现及源码解析

    TTL是Time To Live的缩写,通常意味着元素的生存时间是多长. 应用场景 数据库:在redis中我们最常见的就是缓存我们的数据元素,但是我们又不想其保留太长的时间,因为数据时间越长污染的可能 ...

  10. fseek在 fopen 带有'a'模式下不起作用

    关于 fseek 在 追加写模式的注意事项 结论:fseek在 fopen 带有'a'模式的文件指针偏移不起作用. int main(int argc, char *argv[]) { FILE * ...