common/atari_wrappers.py模块代码如下;

import numpy as np
import os
os.environ.setdefault('PATH', '')
from collections import deque
import gym
from gym import spaces
import cv2
cv2.ocl.setUseOpenCL(False)
from .wrappers import TimeLimit class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
gym.Wrapper.__init__(self, env)
self.noop_max = noop_max
self.override_num_noops = None
self.noop_action = 0
assert env.unwrapped.get_action_meanings()[0] == 'NOOP' def reset(self, **kwargs):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset(**kwargs)
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(self.noop_action)
if done:
obs = self.env.reset(**kwargs)
return obs def step(self, ac):
return self.env.step(ac) class FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3 def reset(self, **kwargs):
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset(**kwargs)
return obs def step(self, ac):
return self.env.step(ac) class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True def step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so it's important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info def reset(self, **kwargs):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset(**kwargs)
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.lives = self.env.unwrapped.ale.lives()
return obs class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env, skip=4):
"""Return only every `skip`-th frame"""
gym.Wrapper.__init__(self, env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
self._skip = skip def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for i in range(self._skip):
obs, reward, done, info = self.env.step(action)
if i == self._skip - 2: self._obs_buffer[0] = obs
if i == self._skip - 1: self._obs_buffer[1] = obs
total_reward += reward
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0) return max_frame, total_reward, done, info def reset(self, **kwargs):
return self.env.reset(**kwargs) class ClipRewardEnv(gym.RewardWrapper):
def __init__(self, env):
gym.RewardWrapper.__init__(self, env) def reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return np.sign(reward) class WarpFrame(gym.ObservationWrapper):
def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
"""
Warp frames to 84x84 as done in the Nature paper and later work. If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
observation should be warped.
"""
super().__init__(env)
self._width = width
self._height = height
self._grayscale = grayscale
self._key = dict_space_key
if self._grayscale:
num_colors = 1
else:
num_colors = 3 new_space = gym.spaces.Box(
low=0,
high=255,
shape=(self._height, self._width, num_colors),
dtype=np.uint8,
)
if self._key is None:
original_space = self.observation_space
self.observation_space = new_space
else:
original_space = self.observation_space.spaces[self._key]
self.observation_space.spaces[self._key] = new_space
assert original_space.dtype == np.uint8 and len(original_space.shape) == 3 def observation(self, obs):
if self._key is None:
frame = obs
else:
frame = obs[self._key] if self._grayscale:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(
frame, (self._width, self._height), interpolation=cv2.INTER_AREA
)
if self._grayscale:
frame = np.expand_dims(frame, -1) if self._key is None:
obs = frame
else:
obs = obs.copy()
obs[self._key] = frame
return obs class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Stack k last frames. Returns lazy array, which is much more memory efficient. See Also
--------
baselines.common.atari_wrappers.LazyFrames
"""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype) def reset(self):
ob = self.env.reset()
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob() def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames)) class ScaledFloatFrame(gym.ObservationWrapper):
def __init__(self, env):
gym.ObservationWrapper.__init__(self, env)
self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32) def observation(self, observation):
# careful! This undoes the memory optimization, use
# with smaller replay buffers only.
return np.array(observation).astype(np.float32) / 255.0 class LazyFrames(object):
def __init__(self, frames):
"""This object ensures that common frames between the observations are only stored once.
It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
buffers. This object should only be converted to numpy array before being passed to the model. You'd not believe how complex the previous solution was."""
self._frames = frames
self._out = None def _force(self):
if self._out is None:
self._out = np.concatenate(self._frames, axis=-1)
self._frames = None
return self._out def __array__(self, dtype=None):
out = self._force()
if dtype is not None:
out = out.astype(dtype)
return out def __len__(self):
return len(self._force()) def __getitem__(self, i):
return self._force()[i] def count(self):
frames = self._force()
return frames.shape[frames.ndim - 1] def frame(self, i):
return self._force()[..., i] def make_atari(env_id, max_episode_steps=None):
env = gym.make(env_id)
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps=max_episode_steps)
return env def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
"""Configure environment for DeepMind-style Atari.
"""
if episode_life:
env = EpisodicLifeEnv(env)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if scale:
env = ScaledFloatFrame(env)
if clip_rewards:
env = ClipRewardEnv(env)
if frame_stack:
env = FrameStack(env, 4)
return env

该模块功能为对gym的env进行包装,也就是修改reset函数,step函数,同时也可以直接修改observation和reward的值,其实现原理为继承gym环境类的gym.Wrapper gym.RewardWrappergym.ObservationWrapper

其中,覆盖gym.Wrapper中的reset函数和step函数可以实现对env的操作包装,对gym.ObservationWrapper中的observation函数进行覆盖可以实现对环境返回的状态进行修改,对gym.RewardWrapper中的reward函数进行覆盖则可以实现对环境返回的奖励值进行修改的目的。

NoopResetEnv类只对reset函数进行覆盖,操作为在reset的时候随机进行一定数量的noop操作,目的是使游戏开始时的初始状态具有一定的变动。

FireResetEnv类只对reset函数进行覆盖,在reset操作的时候加入动作1和动作2,因为有的游戏需要有动作1作为开始键。

需要注意的是上面两个类在reset操作的时候不会出现游戏终止的情况,即dong=True,但是也写了done=True情况的处理操作,不过该部分可以忽略。

至于FireResetEnv为什么在reset的时候加入动作2,也就是env.step(2)还是没有找到解释的地方。

EpisodicLifeEnv类对step函数和reset函数都进行了覆盖,也就是考虑了游戏中有多条命的情况,如果是总游戏没有结束但是一个游戏live结束也对其进行reset操作,不过如果是总的游戏结束则调用包装的内层类reset,如果是lives减一但是不等于0的情况则进行0动作的操作,也就是fire操作,即 env.step(0)。在step函数中判断是总游戏回合结束还是丢失了一个游戏生命。

MaxAndSkipEnv类,进行重复动作并对得到的最近的两个observation进行取最大值,如默认的skip输入值为4,则进行4次的动作重复,也就是MaxAndSkipEnv类对象接受到一个动作的step则会调用内层包装的env进行4次相同action的step。

由于调用内层的env进行skip数量的step会得到4个observation,对最新得到的两个observation进行取max操作后作为最终的observation返回。

对于这个里面的step操作有一些个人的看法,感觉这个官方给出的设计有些瑕疵,自己的修改如下:

    def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for i in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer[i%2] = obs
total_reward += reward
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0) return max_frame, total_reward, done, info

个人的修改主要是考虑了当done=True时还没有对_obs_buffer进行填充的情况,其实这个修改对性能是没有影响的,这样改并不会提升性能,这个只不过可能有些强迫症类型的修改了。同时我们需要注意在step操作中调用内层env进行的skip次的step所得的奖励进行的加和。

类ClipRewardEnv,对step返回的reward进行了处理,也就是说如果有其他的Env类对ClipRewardEnv进行包装,那么ClipRewardEnv向上返回的reward只会为-1,0, +1 。

类WarpFrame(gym.ObservationWrapper)对所有内层的Env的observation进行处理,进行rgb转为灰度图后进行resize操作。需要注意的是rgb图转为灰度图后为保持ndim维度不变会在-1维度上进行扩充。

该类中需要注意的一个地方:

If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
observation should be warped.

dict_space_key可能是给某种环境类型使用的,该种类型的环境返回的observation不是np.array类型而是dict类型,如果需要获得observation中的np.array就需要调用dict_space_key,即observation[dict_space_key]。

把rgb图片转为gray灰度图后维度变少了,需要在-1维度上进行维度扩展。

        if self._grayscale:
frame = np.expand_dims(frame, -1)

类FrameStack(gym.Wrapper) , 将多个图片在-1维度上进行拼接,在reset函数时将内部传入进行包装的env的reset后返回的observation进行k次copy,其中k个图片的存储使用队列形式  deque([], maxlen=k) 。

    def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info

每个step得到的observation都用队列存储,最后向上返回的observation则是使用self._get_ob()生成的。

    def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames))

可以看到每个observation都是k个图片构成的队列生成的LazyFrames对象。

类class LazyFrames(object), 将需要拼接的图片队列转为np.array类型。

类ScaledFloatFrame(gym.ObservationWrapper),将np.uint8类型转为np.float32,范围由0~255转为0~1 。

------------------------------------------------

def make_atari(env_id, max_episode_steps=None):
env = gym.make(env_id)
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps=max_episode_steps)
return env def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
"""Configure environment for DeepMind-style Atari.
"""
if episode_life:
env = EpisodicLifeEnv(env)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if scale:
env = ScaledFloatFrame(env)
if clip_rewards:
env = ClipRewardEnv(env)
if frame_stack:
env = FrameStack(env, 4)
return env

上面这两个类非标选择不同的包装器对env进行包装,也就是deepmind类型的atari环境包装和非deepmind的atari环境包装,这两个用法暂时不是很明白。

=============================================

baselines算法库common/atari_wrappers.py模块分析的更多相关文章

  1. openstack 中 log模块分析

    1 . 所在模块,一般在openstack/common/log.py,其实最主要的还是调用了python中的logging模块: 入口函数在 def setup(product_name, vers ...

  2. 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块

    详细解读Python的web.py框架下的application.py模块   这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...

  3. Python标准库笔记(9) — functools模块

    functools 作用于函数的函数 functools 模块提供用于调整或扩展函数和其他可调用对象的工具,而无需完全重写它们. 装饰器 partial 类是 functools 模块提供的主要工具, ...

  4. python标准库介绍——12 time 模块详解

    ==time 模块== ``time`` 模块提供了一些处理日期和一天内时间的函数. 它是建立在 C 运行时库的简单封装. 给定的日期和时间可以被表示为浮点型(从参考时间, 通常是 1970.1.1 ...

  5. mahout算法库(四)

    mahout算法库 分为三大块 1.聚类算法 2.协同过滤算法(一般用于推荐) 协同过滤算法也可以称为推荐算法!!! 3.分类算法 算法类 算法名 中文名 分类算法               Log ...

  6. scikit-learn 支持向量机算法库使用小结

    之前通过一个系列对支持向量机(以下简称SVM)算法的原理做了一个总结,本文从实践的角度对scikit-learn SVM算法库的使用做一个小结.scikit-learn SVM算法库封装了libsvm ...

  7. OpenRisc-43-or1200的IF模块分析

    引言 “喂饱饥饿的CPU”,是计算机体系结构设计者时刻要考虑的问题.要解决这个问题,方法大体可分为两部分,第一就是利用principle of locality而引进的cache技术,缩短取指时间,第 ...

  8. 【转】python模块分析之unittest测试(五)

    [转]python模块分析之unittest测试(五) 系列文章 python模块分析之random(一) python模块分析之hashlib加密(二) python模块分析之typing(三) p ...

  9. 【转】python模块分析之hashlib加密(二)

    [转]python模块分析之hashlib加密(二) hashlib模块是用来对字符串进行hash加密的模块,明文与密文是一一对应不变的关系:用于注册.登录时用户名.密码等加密使用.一.函数分析:1. ...

  10. 【转】python之random模块分析(一)

    [转]python之random模块分析(一) random是python产生伪随机数的模块,随机种子默认为系统时钟.下面分析模块中的方法: 1.random.randint(start,stop): ...

随机推荐

  1. Prometheus + Grafana (2) mysql、redis、Docker容器、服务端点以及预警

    接着上一节 <Prometheus + Grafana (1) 监控 >,我们继续探讨 Prometheus + Grafana 的复杂应用 实现目标 这节我们的目标是搭建一个多维度监控微 ...

  2. 详解Web应用安全系列(2)注入漏洞之XSS攻击

    上一篇介绍了SQL注入漏洞,今天我们来介绍另一个注入漏洞,即XSS跨站脚本攻击.XSS 全称(Cross Site Scripting) 跨站脚本攻击, 是Web应用中常见的漏洞.指攻击者在网页中嵌入 ...

  3. 《Vue3.x +TpyeScript实践指南》勘误

    图书出版已有一段时间,书中已发现错误如下: 书的第14页,倒数第3行,npm init -y命令中,init和-y之间应该有个空格: 书的第32页,代码的第1行,应该为模板字符串符号 `,我看印刷的是 ...

  4. bs4解析-优美图库

    import requests from bs4 import BeautifulSoup url = 'http://www.umeituku.com/bizhitupian/meinvbizhi/ ...

  5. vue大型电商项目尚品汇(前台篇)day03

    堆积了两天一起发的,先祝大家节日快乐 后面任务很繁重,还有登录注册组件还有后台管理页面,真的繁重,我现在感觉每天全天时间都在学都不一定学得完,主要想在六月一号之前把整个项目过一遍.看看能不能创造奇迹 ...

  6. 分布式文件系统 FastDFS 整理

    1.FastDFS 1.1.了解基础概念 1.1.1.什么是分布式文件系统? 全称:Distributed File System,即简称的DFS 这个东西可以是一个软件,也可以说是服务器,和tomc ...

  7. computed 和 watch 的区别和运用的场景?

    computed: 是计算属性,依赖其它属性值,并且 computed 的值有缓存,只有它依赖的属性值发生改变,下一次获取 computed 的值时才会重新计算 computed 的值: watch: ...

  8. Spring AOP面向切面编程核心概念

    横切关注点 对那些方法进行拦截,拦截后怎么处理,这些就叫横切关注点 比如:权限认证.日志.事务 通知 Advice 在特定的切入点上执行的增强处理,有5种通知 用途:记录日志.控制事务.提前编写好通用 ...

  9. 机器学习策略篇:快速搭建你的第一个系统,并进行迭代(Build your first system quickly, then iterate)

    快速搭建的第一个系统,并进行迭代 如果正在考虑建立一个新的语音识别系统,其实可以走很多方向,可以优先考虑很多事情. 比如,有一些特定的技术,可以让语音识别系统对嘈杂的背景更加健壮,嘈杂的背景可能是说咖 ...

  10. 云服务器安装宝塔Linux面板教程(建议收藏)

    ​ 一.简介 宝塔面板是一款简单好用的服务器运维面板.它支持一键LAMP/LNMP/集群/监控/网站/FTP/数据库/JAVA等100多项服务器管理功能.对于新手用云服务器来建站的话,宝塔面板是一个非 ...