这个难度有些大,有两个policy,一个负责更新策略,另一个负责提供数据,实际这两个policy是一个东西,用policy1跑出一组数据给新的policy2训练,然后policy2跑数据给新的policy3训练,,,,直到policy(N-1)跑数据给新的policyN训练,过程感觉和DQN比较像,但是模型是actor critic 架构,on-policy转换成off-policy,使用剪切策略来限制策略的更新幅度,off-policy的好处是策略更新快,PPO的优化目标是最大化策略的期望回报,同时避免策略更新过大

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pygame
import sys
from collections import deque # 定义策略网络
class PolicyNetwork(nn.Module):
def __init__(self):
super(PolicyNetwork, self).__init__()
self.fc = nn.Sequential(
nn.Linear(4, 2),
nn.Tanh(),
nn.Linear(2, 2), # CartPole的动作空间为2
nn.Softmax(dim=-1)
) def forward(self, x):
return self.fc(x) # 定义值网络
class ValueNetwork(nn.Module):
def __init__(self):
super(ValueNetwork, self).__init__()
self.fc = nn.Sequential(
nn.Linear(4, 2),
nn.Tanh(),
nn.Linear(2, 1)
) def forward(self, x):
return self.fc(x) # 经验回放缓冲区
class RolloutBuffer:
def __init__(self):
self.states = []
self.actions = []
self.rewards = []
self.dones = []
self.log_probs = [] def store(self, state, action, reward, done, log_prob):
self.states.append(state)
self.actions.append(action)
self.rewards.append(reward)
self.dones.append(done)
self.log_probs.append(log_prob) def clear(self):
self.states = []
self.actions = []
self.rewards = []
self.dones = []
self.log_probs = [] def get_batch(self):
return (
torch.tensor(self.states, dtype=torch.float),
torch.tensor(self.actions, dtype=torch.long),
torch.tensor(self.rewards, dtype=torch.float),
torch.tensor(self.dones, dtype=torch.bool),
torch.tensor(self.log_probs, dtype=torch.float)
) # PPO更新函数
def ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer, epochs=10, gamma=0.99, clip_param=0.2):
states, actions, rewards, dones, old_log_probs = buffer.get_batch()
returns = []
advantages = []
G = 0
adv = 0
dones = dones.to(torch.int)
# print(dones)
for reward, done, value in zip(reversed(rewards), reversed(dones), reversed(value_net(states))):
if done:
G = 0
adv = 0
G = reward + gamma * G #蒙特卡洛回溯G值
delta = reward + gamma * value.item() * (1 - done) - value.item() #TD差分
# adv = delta + gamma * 0.95 * adv * (1 - done) #
adv = delta + adv*(1-done)
returns.insert(0, G)
advantages.insert(0, adv) returns = torch.tensor(returns, dtype=torch.float) #价值
advantages = torch.tensor(advantages, dtype=torch.float)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) #add baseline for _ in range(epochs):
action_probs = policy_net(states)
dist = torch.distributions.Categorical(action_probs)
new_log_probs = dist.log_prob(actions)
ratio = (new_log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
actor_loss = -torch.min(surr1, surr2).mean() optimizer_policy.zero_grad()
actor_loss.backward()
optimizer_policy.step() value_loss = (returns - value_net(states)).pow(2).mean() optimizer_value.zero_grad()
value_loss.backward()
optimizer_value.step() # 初始化环境和模型
env = gym.make('CartPole-v1')
policy_net = PolicyNetwork()
value_net = ValueNetwork()
optimizer_policy = optim.Adam(policy_net.parameters(), lr=3e-4)
optimizer_value = optim.Adam(value_net.parameters(), lr=1e-3)
buffer = RolloutBuffer() # Pygame初始化
pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock() draw_on = False
# 训练循环
state = env.reset()
for episode in range(10000): # 训练轮次
done = False
state = state[0]
step= 0
while not done:
step+=1
state_tensor = torch.FloatTensor(state).unsqueeze(0)
action_probs = policy_net(state_tensor)
dist = torch.distributions.Categorical(action_probs)
action = dist.sample()
log_prob = dist.log_prob(action) next_state, reward, done, _ ,_ = env.step(action.item())
buffer.store(state, action.item(), reward, done, log_prob) state = next_state # 实时显示
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
sys.exit() if draw_on:
# 清屏并重新绘制
screen.fill((0, 0, 0))
cart_x = int(state[0] * 100 + 300) # 位置转换为屏幕坐标
pygame.draw.rect(screen, (0, 128, 255), (cart_x, 300, 50, 30))
pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * np.sin(state[2])), 300 - int(50 * np.cos(state[2]))), 5)
pygame.display.flip()
clock.tick(600) if step >10000:
draw_on = True
ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer)
buffer.clear()
state = env.reset()
print(f'Episode {episode} completed {step}.') # 结束训练
env.close()
pygame.quit()

运行效果

PPO近端策略优化玩cartpole游戏的更多相关文章

  1. TensorFlow利用A3C算法训练智能体玩CartPole游戏

    本教程讲解如何使用深度强化学习训练一个可以在 CartPole 游戏中获胜的模型.研究人员使用 tf.keras.OpenAI 训练了一个使用「异步优势动作评价」(Asynchronous Advan ...

  2. DRL 教程 | 如何保持运动小车上的旗杆屹立不倒?TensorFlow利用A3C算法训练智能体玩CartPole游戏

    本教程讲解如何使用深度强化学习训练一个可以在 CartPole 游戏中获胜的模型.研究人员使用 tf.keras.OpenAI 训练了一个使用「异步优势动作评价」(Asynchronous Advan ...

  3. 适合码农工作时玩的游戏:Scrum

    适合码农工作时玩的游戏:Scrum 昨天遇到一个来自微软的面试者,在面试的最后,我简单介绍了一下我们团队使用一周一次的 Scrum 来做项目管理.他回答说:” 我在微软也用 Scrum,不过我们一周两 ...

  4. 玩QQ游戏,见到好几个图像是美女的QQ,就不始玩

    玩QQ游戏,见到好几个图像是美女的QQ,光占坑就是不开始玩 加了一个,发现是传播不良网站的QQ 聊天还是自动的 估计是利用webqq写的程序,也就那几句话来回重复,让你去注册网站什么 可以加这个Q去体 ...

  5. 使用PS3手柄在PC玩Unity3D游戏

    PS3手柄玩Unity游戏 今天把公司的PS3手柄接到PC上,想用手柄试一下玩赛车的感觉,老感觉用键盘按键玩的不爽. 把PS3的手柄接到PC上之后,系统提示正在安装驱动--,百度找资料,如何在PC上使 ...

  6. 伯克利、OpenAI等提出基于模型的元策略优化强化学习

    基于模型的强化学习方法数据效率高,前景可观.本文提出了一种基于模型的元策略强化学习方法,实践证明,该方法比以前基于模型的方法更能够应对模型缺陷,还能取得与无模型方法相近的性能. 引言 强化学习领域近期 ...

  7. 用python玩推理游戏还能掌握基础知识点,有趣又充实,你不试试吗?

    可能更多的人依然还在苦苦的学python各种知识点,但其实同样很多人,玩着游戏就把python学会了.     用python玩推理游戏,是这份python教程中的12个游戏的其中之一. 有关这份Py ...

  8. Linux系统中有趣的命令(可以玩小游戏)

    Linux系统中有趣的命令(可以玩小游戏) 前言 最近,我在看一些关于Linux系统的内容,这里面的内容是真的越学越枯燥,果然学习的过程还是不容易的.记得前几个月初学Linux时,有时候就会碰到小彩蛋 ...

  9. Bert不完全手册3. Bert训练策略优化!RoBERTa & SpanBERT

    之前看过一条评论说Bert提出了很好的双向语言模型的预训练以及下游迁移的框架,但是它提出的各种训练方式槽点较多,或多或少都有优化的空间.这一章就训练方案的改良,我们来聊聊RoBERTa和SpanBER ...

  10. 策略梯度训练cartpole小游戏

    我原来已经安装了anaconda,在此基础上进入cmd进行pip install tensorflow和pip install gym就可以了. 在win10的pycharm做的. policy_gr ...

随机推荐

  1. #二进制拆分,矩阵乘法#洛谷 6569 [NOI Online #3 提高组] 魔法值

    题目 分析 考虑一个点的权值能被统计到答案当且仅当其到1号点的路径条数为奇数条. 那么设 \(dp[i][x][y]\) 表示从 \(x\) 到 \(y\) 走 \(i\) 步路径条数的奇偶性, 这个 ...

  2. mysql 简单进阶 ———— 重构查询[二]

    前言 简单整理一下重构查询. 正文 为什么我们需要重构查询,原因也很简单,那就是查询慢. 为什么会查询慢? 查询性能慢底下的最基本的原因是访问的数据太多. 某些查询不可避免地需要筛选大量的数据,但这并 ...

  3. 力扣577(MySQL)-员工奖金(简单)

    题目: 问题:选出所有 bonus < 1000 的员工的 name 及其 bonus.Employee 表单,empId 是这张表单的主关键字 Bonus 表单,empId 是这张表单的主关键 ...

  4. IDC:云效产品能力No.1,领跑中国DevOps市场

    简介: 近日,全球领先的专业市场调查机构国际数据公司(IDC)发布了<IDC MarketScape:中国 DevOps 平台市场厂商评估,2022>报告.此报告中对中国主流 DevOps ...

  5. 【USENIX ATC】支持异构GPU集群的超大规模模型的高效的分布式训练框架Whale

    简介: 高效大模型训练框架Whale(EPL)入选USENIX ATC 作者:张杰.贾贤艳 近日,阿里云机器学习PAI关于深度学习模型高效的分布式训练框架的论文< Whale: Efficien ...

  6. 阿里云开源业内首个应用多活项目 AppActive,与社区共建云原生容灾标准

    ​简介:继高可用架构团队的 Sentinel.Chaosblade 开源后,第三个重磅高可用产品:应用多活 AppActive 正式开源,形成高可用的三架马车,帮助企业构建稳定可靠的企业级生产系统,提 ...

  7. Apsara Stack 技术百科 | 标准化的云时代:一云多芯

    ​简介:随着今年云栖大会现场平头哥的自研云芯片倚天710发布,以及众多新兴厂商的芯片发布,将有越来越多的类型芯片进入到主流市场,"多芯"的架构将在数据中心中越来越常见,阿里云混合云 ...

  8. OpenKruise v0.9.0 版本发布:新增 Pod 重启、删除防护等重磅功能

    简介: OpenKruise 是阿里云开源的云原生应用自动化管理套件,也是当前托管在 Cloud Native Computing Foundation (CNCF) 下的 Sandbox 项目.它来 ...

  9. IDA动态调试快捷键

    1. F2下断点2. F7进入函数,F8单步调试,F9跳到下一个断点,F2下断点,G调到函数地址3. N重名4. g跳到地址和函数名5. u取消把函数汇编变成机器码6. c就是把机器码变成汇编7. F ...

  10. [FAQ] 快速上手 Final Cut Pro X 的入门教程

    FinalCutPro视频剪辑 基本操作教学,看下面的视频作为一个大致了解.另外遇到其它问题再针对性搜索解决即可. > 在线CF靶场 射击消除烦闷 Link:https://www.cnblog ...