强化学习从基础到进阶-案例与实践[4.2]:深度Q网络DQN-Cart pole游戏展示

  • 强化学习(Reinforcement learning,简称RL)是机器学习中的一个领域,区别与监督学习和无监督学习,强调如何基于环境而行动,以取得最大化的预期利益。
  • 基本操作步骤:智能体agent在环境environment中学习,根据环境的状态state(或观测到的observation),执行动作action,并根据环境的反馈reward(奖励)来指导更好的动作。

比如本项目的Cart pole小游戏中,agent就是动图中的杆子,杆子有向左向右两种action

## 安装依赖
!pip install pygame
!pip install gym
!pip install atari_py
!pip install parl
import gym
import os
import random
import collections import paddle
import paddle.nn as nn
import numpy as np
import paddle.nn.functional as F

1.经验回放部分

经验回放主要做的事情是:把结果存入经验池,然后经验池中随机取出一条结果进行训练。

这样做有两个好处:

  1. 减少样本之间的关联性
  2. 提高样本的利用率

之所以加入experience replay是因为样本是从游戏中的连续帧获得的,这与简单的reinforcement learning问题相比,样本的关联性大了很多,如果没有experience replay,算法在连续一段时间内基本朝着同一个方向做gradient descent,那么同样的步长下这样直接计算gradient就有可能不收敛。因此experience replay是从一个memory pool中随机选取了一些expeirence,然后再求梯度,从而避免了这个问题。

class ReplayMemory(object):
def __init__(self, max_size):
self.buffer = collections.deque(maxlen=max_size) # 增加一条经验到经验池中
def append(self, exp):
self.buffer.append(exp) # 从经验池中选取N条经验出来
def sample(self, batch_size):
mini_batch = random.sample(self.buffer, batch_size)
obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], [] for experience in mini_batch:
s, a, r, s_p, done = experience
obs_batch.append(s)
action_batch.append(a)
reward_batch.append(r)
next_obs_batch.append(s_p)
done_batch.append(done) return np.array(obs_batch).astype('float32'), np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'), np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32') def __len__(self):
return len(self.buffer)

2.DQN

DQN算法较普通算法在经验回放和固定Q目标有了较大的改进,主要原因:

  • 经验回放:他充分利用了off-colicp的优势,通过训练把结果(成绩)存入Q表格,然后随机从表格中取出一条结果进行优化。这样子一方面可以:减少样本之间的关联性另一方面:提高样本的利用率 注:训练结果会存进Q表格,当Q表格满了以后,存进来的数据会把最早存进去的数据“挤出去”(弹出)
  • 固定Q目标他解决了算法更新不平稳的问题。 和监督学习做比较,监督学习的最终值要逼近实际结果,这个结果是固定的,但是我们的DQN却不是,他的目标值是经过神经网络以后的一个值,那么这个值是变动的不好拟合,怎么办,DQN团队想到了一个很好的办法,让这个值在一定时间里面保持不变,这样子这个目标就可以确定了,然后目标值更新以后更加接近实际结果,可以更好的进行训练。

3.模型Model

这里的模型可以根据自己的需求选择不同的神经网络组建。

DQN用来定义前向(Forward)网络,可以自由的定制自己的网络结构。

class DQN(nn.Layer):
def __init__(self, outputs):
super(DQN, self).__init__()
self.linear1 = nn.Linear(in_features=4, out_features=128)
self.linear2 = nn.Linear(in_features=128, out_features=24)
self.linear3 = nn.Linear(in_features=24, out_features=outputs) def forward(self, x):
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x

4.智能体Agent的学习函数

这里包括模型探索与模型训练两个部分

Agent负责算法与环境的交互,在交互过程中把生成的数据提供给Algorithm来更新模型(Model),数据的预处理流程也一般定义在这里。

def sample(obs, MODEL):
global E_GREED
global ACTION_DIM
global E_GREED_DECREMENT
sample = np.random.rand() # 产生0~1之间的小数
if sample < E_GREED:
act = np.random.randint(ACTION_DIM) # 探索:每个动作都有概率被选择
else:
obs = np.expand_dims(obs, axis=0)
obs = paddle.to_tensor(obs, dtype='float32')
act = MODEL(obs)
act = np.argmax(act.numpy()) # 选择最优动作
E_GREED = max(0.01, E_GREED - E_GREED_DECREMENT) # 随着训练逐步收敛,探索的程度慢慢降低
return act def learn(obs, act, reward, next_obs, terminal, TARGET_MODEL, MODEL):
global global_step
# 每隔200个training steps同步一次model和target_model的参数
if global_step % 50 == 0:
TARGET_MODEL.load_dict(MODEL.state_dict())
global_step += 1 obs = np.array(obs).astype('float32')
next_obs = np.array(next_obs).astype('float32')
# act = np.expand_dims(act, -1)
cost = optimize_model(obs, act, reward, next_obs,
terminal, TARGET_MODEL, MODEL) # 训练一次网络
return cost def optimize_model(obs, action, reward, next_obs, terminal, TARGET_MODEL, MODEL):
"""
使用DQN算法更新self.model的value网络
"""
# 从target_model中获取 max Q' 的值,用于计算target_Q
global E_GREED
global ACTION_DIM
global E_GREED_DECREMENT
global GAMMA
global LEARNING_RATE
global opt opt = paddle.optimizer.Adam(learning_rate=LEARNING_RATE,
parameters=MODEL.parameters()) # 优化器(动态图) obs = paddle.to_tensor(obs)
next_obs = paddle.to_tensor(next_obs) next_pred_value = TARGET_MODEL(next_obs).detach()
best_v = paddle.max(next_pred_value, axis=1)
target = reward + (1.0 - terminal) * GAMMA * best_v.numpy()
target = paddle.to_tensor(target)
pred_value = MODEL(obs) # 获取Q预测值
# 将action转onehot向量,比如:3 => [0,0,0,1,0]
action = paddle.to_tensor(action.astype('int32'))
action_onehot = F.one_hot(action, ACTION_DIM)
action_onehot = paddle.cast(action_onehot, dtype='float32')
# 下面一行是逐元素相乘,拿到action对应的 Q(s,a)
pred_action_value = paddle.sum(paddle.multiply(action_onehot, pred_value), axis=1)
# 计算 Q(s,a) 与 target_Q的均方差,得到loss
cost = F.square_error_cost(pred_action_value, target)
cost = paddle.mean(cost)
avg_cost = cost
cost.backward()
opt.step()
opt.clear_grad() return avg_cost.numpy()

5.模型梯度更新算法

def run_train(env, rpm, TARGET_MODEL, MODEL):
MODEL.train()
TARGET_MODEL.train()
total_reward = 0
obs = env.reset() global global_step
while True:
global_step += 1
# 获取随机动作和执行游戏
action = sample(obs, MODEL) next_obs, reward, isOver, info = env.step(action) # 记录数据
rpm.append((obs, action, reward, next_obs, isOver)) # 在预热完成之后,每隔LEARN_FREQ步数就训练一次
if (len(rpm) > MEMORY_WARMUP_SIZE) and (global_step % LEARN_FREQ == 0):
(batch_obs, batch_action, batch_reward, batch_next_obs, batch_isOver) = rpm.sample(BATCH_SIZE)
train_loss = learn(batch_obs, batch_action, batch_reward,
batch_next_obs, batch_isOver, TARGET_MODEL, MODEL) total_reward += reward
obs = next_obs.astype('float32') # 结束游戏
if isOver:
break
return total_reward def evaluate(model, env, render=False):
model.eval()
eval_reward = []
for i in range(5):
obs = env.reset()
episode_reward = 0
while True:
obs = np.expand_dims(obs, axis=0)
obs = paddle.to_tensor(obs, dtype='float32')
action = model(obs)
action = np.argmax(action.numpy())
obs, reward, done, _ = env.step(action)
episode_reward += reward
if render:
env.render()
if done:
break
eval_reward.append(episode_reward)
return np.mean(eval_reward)

6.训练函数与验证函数

设置超参数

LEARN_FREQ = 5  # 训练频率,不需要每一个step都learn,攒一些新增经验后再learn,提高效率
MEMORY_SIZE = 20000 # replay memory的大小,越大越占用内存
MEMORY_WARMUP_SIZE = 200 # replay_memory 里需要预存一些经验数据,再开启训练
BATCH_SIZE = 32 # 每次给agent learn的数据数量,从replay memory随机里sample一批数据出来
LEARNING_RATE = 0.001 # 学习率大小
GAMMA = 0.99 # reward 的衰减因子,一般取 0.9 到 0.999 不等 E_GREED = 0.1 # 探索初始概率
E_GREED_DECREMENT = 1e-6 # 在训练过程中,降低探索的概率
MAX_EPISODE = 20000 # 训练次数
SAVE_MODEL_PATH = "models/save" # 保存模型路径
OBS_DIM = None
ACTION_DIM = None
global_step = 0
def main():
global OBS_DIM
global ACTION_DIM train_step_list = []
train_reward_list = []
evaluate_step_list = []
evaluate_reward_list = [] # 初始化游戏
env = gym.make('CartPole-v0')
# 图像输入形状和动作维度
action_dim = env.action_space.n
obs_dim = env.observation_space.shape
OBS_DIM = obs_dim
ACTION_DIM = action_dim
max_score = -int(1e4) # 创建存储执行游戏的内存
rpm = ReplayMemory(MEMORY_SIZE)
MODEL = DQN(ACTION_DIM)
TARGET_MODEL = DQN(ACTION_DIM)
# if os.path.exists(os.path.dirname(SAVE_MODEL_PATH)):
# MODEL_DICT = paddle.load(SAVE_MODEL_PATH+'.pdparams')
# MODEL.load_dict(MODEL_DICT) # 加载模型参数
print("filling memory...")
while len(rpm) < MEMORY_WARMUP_SIZE:
run_train(env, rpm, TARGET_MODEL, MODEL)
print("filling memory done") # 开始训练
episode = 0 print("start training...")
# 训练max_episode个回合,test部分不计算入episode数量
while episode < MAX_EPISODE:
# train part
for i in range(0, int(50)):
# First we need a state
total_reward = run_train(env, rpm, TARGET_MODEL, MODEL)
episode += 1 # print("episode:{} reward:{}".format(episode, str(total_reward))) # test part
# print("start evaluation...")
eval_reward = evaluate(TARGET_MODEL, env)
print('episode:{} e_greed:{} test_reward:{}'.format(episode, E_GREED, eval_reward)) evaluate_step_list.append(episode)
evaluate_reward_list.append(eval_reward) # if eval_reward > max_score or not os.path.exists(os.path.dirname(SAVE_MODEL_PATH)):
# max_score = eval_reward
# paddle.save(TARGET_MODEL.state_dict(), SAVE_MODEL_PATH+'.pdparams') # 保存模型 if __name__ == '__main__':
main()

filling memory...

filling memory done

start training...

episode:50 e_greed:0.0992949999999993 test_reward:9.0

episode:100 e_greed:0.0987909999999988 test_reward:9.8

episode:150 e_greed:0.09827199999999828 test_reward:10.0

episode:200 e_greed:0.09777599999999778 test_reward:8.8

episode:250 e_greed:0.09726999999999728 test_reward:9.0

episode:300 e_greed:0.09676199999999677 test_reward:10.0

episode:350 e_greed:0.0961919999999962 test_reward:14.8

项目链接fork一下即可运行

https://www.heywhale.com/mw/project/649e7d3f70567260f8f11d2b

更多优质内容请关注公号:汀丶人工智能

强化学习从基础到进阶-案例与实践[4.2]:深度Q网络DQN-Cart pole游戏展示的更多相关文章

  1. 强化学习入门基础-马尔可夫决策过程(MDP)

    作者:YJLAugus 博客: https://www.cnblogs.com/yjlaugus 项目地址:https://github.com/YJLAugus/Reinforcement-Lear ...

  2. 【算法总结】强化学习部分基础算法总结(Q-learning DQN PG AC DDPG TD3)

    总结回顾一下近期学习的RL算法,并给部分实现算法整理了流程图.贴了代码. 1. value-based 基于价值的算法 基于价值算法是通过对agent所属的environment的状态或者状态动作对进 ...

  3. [原创]java WEB学习笔记21:MVC案例完整实践(part 2)---DAO层设计

    本博客为原创:综合 尚硅谷(http://www.atguigu.com)的系统教程(深表感谢)和 网络上的现有资源(博客,文档,图书等),资源的出处我会标明 本博客的目的:①总结自己的学习过程,相当 ...

  4. [原创]java WEB学习笔记20:MVC案例完整实践(part 1)---MVC架构分析

    本博客为原创:综合 尚硅谷(http://www.atguigu.com)的系统教程(深表感谢)和 网络上的现有资源(博客,文档,图书等),资源的出处我会标明 本博客的目的:①总结自己的学习过程,相当 ...

  5. MySQL学习笔记——基础与进阶篇

    目录 一.###MySQL登录和退出 二.###MySQL常用命令 三.###MySQL语法规范 四.###基础查询 五.###条件查询 六.###排序查询 七.###常见函数的学习 八.###分组查 ...

  6. [原创]java WEB学习笔记26:MVC案例完整实践(part 7)---修改的设计和实现

    本博客为原创:综合 尚硅谷(http://www.atguigu.com)的系统教程(深表感谢)和 网络上的现有资源(博客,文档,图书等),资源的出处我会标明 本博客的目的:①总结自己的学习过程,相当 ...

  7. [原创]java WEB学习笔记25:MVC案例完整实践(part 6)---新增操作的设计与实现

    本博客为原创:综合 尚硅谷(http://www.atguigu.com)的系统教程(深表感谢)和 网络上的现有资源(博客,文档,图书等),资源的出处我会标明 本博客的目的:①总结自己的学习过程,相当 ...

  8. [原创]java WEB学习笔记24:MVC案例完整实践(part 5)---删除操作的设计与实现

    本博客为原创:综合 尚硅谷(http://www.atguigu.com)的系统教程(深表感谢)和 网络上的现有资源(博客,文档,图书等),资源的出处我会标明 本博客的目的:①总结自己的学习过程,相当 ...

  9. [原创]java WEB学习笔记23:MVC案例完整实践(part 4)---模糊查询的设计与实现

    本博客为原创:综合 尚硅谷(http://www.atguigu.com)的系统教程(深表感谢)和 网络上的现有资源(博客,文档,图书等),资源的出处我会标明 本博客的目的:①总结自己的学习过程,相当 ...

  10. [原创]java WEB学习笔记22:MVC案例完整实践(part 3)---多个请求对应一个Servlet解析

    本博客为原创:综合 尚硅谷(http://www.atguigu.com)的系统教程(深表感谢)和 网络上的现有资源(博客,文档,图书等),资源的出处我会标明 本博客的目的:①总结自己的学习过程,相当 ...

随机推荐

  1. Java 线程间通信 —— 管道输入 / 输出流

    本文部分摘自<Java 并发编程的艺术> 管道输入 / 输出流 管道输入 / 输出流和普通的文件输入 / 输出流或者网络输入 / 输出流不同之处在于,它主要用于线程之间的数据传输,而传输媒 ...

  2. pip 的高阶玩法

    pip 的高阶玩法 pip 应该是大家最熟悉的 Python 包安装与管理工具了,但是除了pip install 这个最常用的命令,还有很多有用的玩法.这里就介绍几个我平时会用到的,希望对大家有所帮助 ...

  3. Android WebView 踩坑日记,字体怎么突然变小了???

    背景 最近,端内在做 webView 统一的时候,个性签名中的 WebView 替换为 CustomWebView 之后,发现字体突然变小. 一开始不知道是什么原因,通过二分法查找最近的提交,排查之后 ...

  4. 简单的git使用命令

    一.Git简介       Git(读音为/gɪt/.)是一个开源的分布式版本控制系统,可以有效.高速地处理从很小到非常大的项目版本管理.Git 是 Linus Torvalds 为了帮助管理 Lin ...

  5. 面试官:SpringBoot如何实现缓存预热?

    缓存预热是指在 Spring Boot 项目启动时,预先将数据加载到缓存系统(如 Redis)中的一种机制. 那么问题来了,在 Spring Boot 项目启动之后,在什么时候?在哪里可以将数据加载到 ...

  6. wireshark 报文颜色

    在使用wireshark抓包分析的过程中,默认会对不同的包进行着色,截图如下: 对不同的颜色有了解,可快速的过滤包或分析请求. 菜单栏选择视图-->着色规则,即可看到不同颜色代表的含义: 大致可 ...

  7. 【SHELL】查找文件并删除

    find . -iname file-name |xargs -I % rm -rf %

  8. Harbor镜像仓库的导出与整理之二

    Harbor镜像仓库的导出与整理之二 背景 前几天参照大神的blog进行了一下harbor的镜像列表的获取与下载. 当时发现一个很诡异的问题. 实际上镜像仓库里面的镜像很多. 但是导出和列表里面的却很 ...

  9. [转帖]TiDB修改配置参数

    https://www.jianshu.com/p/2ecdb4642579 在TiDB 中,"修改配置参数"似乎是个不精准的说法,它实际包含了以下内容: 修改 TiDB 的系统变 ...

  10. [转帖]DOCKER默认网段和主机网段冲突解决

    https://www.cnblogs.com/yinliang/p/13189334.html 一. docker默认网卡docker0 172.17.0.0可能会与主机冲突,这时候需要修改dock ...