其实KL散度在这个游戏里的作用不大,游戏的action比较简单,不像LM里的action是一个很大的向量,可以直接用surr1,最大化surr1,实验测试确实是这样,而且KL的系数不能给太大,否则惩罚力度太大,action model 和ref model产生的action其实分布的差距并不太大

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pygame
import sys
from collections import deque # 定义策略网络
class PolicyNetwork(nn.Module):
def __init__(self):
super(PolicyNetwork, self).__init__()
self.fc = nn.Sequential(
nn.Linear(4, 2),
nn.Tanh(),
nn.Linear(2, 2), # CartPole的动作空间为2
nn.Softmax(dim=-1)
) def forward(self, x):
return self.fc(x) # 定义值网络
class ValueNetwork(nn.Module):
def __init__(self):
super(ValueNetwork, self).__init__()
self.fc = nn.Sequential(
nn.Linear(4, 2),
nn.Tanh(),
nn.Linear(2, 1)
) def forward(self, x):
return self.fc(x) # 经验回放缓冲区
class RolloutBuffer:
def __init__(self):
self.states = []
self.actions = []
self.rewards = []
self.dones = []
self.log_probs = [] def store(self, state, action, reward, done, log_prob):
self.states.append(state)
self.actions.append(action)
self.rewards.append(reward)
self.dones.append(done)
self.log_probs.append(log_prob) def clear(self):
self.states = []
self.actions = []
self.rewards = []
self.dones = []
self.log_probs = [] def get_batch(self):
return (
torch.tensor(self.states, dtype=torch.float),
torch.tensor(self.actions, dtype=torch.long),
torch.tensor(self.rewards, dtype=torch.float),
torch.tensor(self.dones, dtype=torch.bool),
torch.tensor(self.log_probs, dtype=torch.float)
) # PPO更新函数
def ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer, epochs=100, gamma=0.99, clip_param=0.2):
states, actions, rewards, dones, old_log_probs = buffer.get_batch()
returns = []
advantages = []
G = 0
adv = 0
dones = dones.to(torch.int)
# print(dones)
for reward, done, value in zip(reversed(rewards), reversed(dones), reversed(value_net(states))):
if done:
G = 0
adv = 0
G = reward + gamma * G #蒙特卡洛回溯G值
delta = reward + gamma * value.item() * (1 - done) - value.item() #TD差分
# adv = delta + gamma * 0.95 * adv * (1 - done) #
adv = delta + adv*(1-done)
returns.insert(0, G)
advantages.insert(0, adv) returns = torch.tensor(returns, dtype=torch.float) #价值
advantages = torch.tensor(advantages, dtype=torch.float)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) #add baseline for _ in range(epochs):
action_probs = policy_net(states)
dist = torch.distributions.Categorical(action_probs)
new_log_probs = dist.log_prob(actions)
ratio = (new_log_probs - old_log_probs).exp() KL = new_log_probs.exp()*(new_log_probs - old_log_probs).mean() #KL散度 p*log(p/p')
#下面三行是核心
surr1 = ratio * advantages PPO1,PPO2 = True,False
# print(surr1,KL*500)
if PPO1 == True:
actor_loss = -(surr1 - KL).mean() if PPO2 == True:
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
actor_loss = -torch.min(surr1, surr2).mean() optimizer_policy.zero_grad()
actor_loss.backward()
optimizer_policy.step() value_loss = (returns - value_net(states)).pow(2).mean() optimizer_value.zero_grad()
value_loss.backward()
optimizer_value.step() # 初始化环境和模型
env = gym.make('CartPole-v1')
policy_net = PolicyNetwork()
value_net = ValueNetwork()
optimizer_policy = optim.Adam(policy_net.parameters(), lr=3e-4)
optimizer_value = optim.Adam(value_net.parameters(), lr=1e-3)
buffer = RolloutBuffer() # Pygame初始化
pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock() draw_on = False
# 训练循环
state = env.reset()
for episode in range(10000): # 训练轮次
done = False
state = state[0]
step= 0
while not done:
step+=1
state_tensor = torch.FloatTensor(state).unsqueeze(0)
action_probs = policy_net(state_tensor) #旧policy推理数据
dist = torch.distributions.Categorical(action_probs)
action = dist.sample()
log_prob = dist.log_prob(action) next_state, reward, done, _ ,_ = env.step(action.item())
buffer.store(state, action.item(), reward, done, log_prob) state = next_state # 实时显示
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
sys.exit() if draw_on:
# 清屏并重新绘制
screen.fill((0, 0, 0))
cart_x = int(state[0] * 100 + 300) # 位置转换为屏幕坐标
pygame.draw.rect(screen, (0, 128, 255), (cart_x, 300, 50, 30))
pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * np.sin(state[2])), 300 - int(50 * np.cos(state[2]))), 5)
pygame.display.flip()
clock.tick(60) if step >2000:
draw_on = True
ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer)
buffer.clear()
state = env.reset()
print(f'Episode {episode} completed , reward: {step}.') # 结束训练
env.close()
pygame.quit()

效果:

PPO-KL散度近端策略优化玩cartpole游戏的更多相关文章

  1. TensorFlow利用A3C算法训练智能体玩CartPole游戏

    本教程讲解如何使用深度强化学习训练一个可以在 CartPole 游戏中获胜的模型.研究人员使用 tf.keras.OpenAI 训练了一个使用「异步优势动作评价」(Asynchronous Advan ...

  2. DRL 教程 | 如何保持运动小车上的旗杆屹立不倒?TensorFlow利用A3C算法训练智能体玩CartPole游戏

    本教程讲解如何使用深度强化学习训练一个可以在 CartPole 游戏中获胜的模型.研究人员使用 tf.keras.OpenAI 训练了一个使用「异步优势动作评价」(Asynchronous Advan ...

  3. KL散度的理解(GAN网络的优化)

    原文地址Count Bayesie 这篇文章是博客Count Bayesie上的文章Kullback-Leibler Divergence Explained 的学习笔记,原文对 KL散度 的概念诠释 ...

  4. (转)KL散度的理解

    KL散度(KL divergence) 全称:Kullback-Leibler Divergence. 用途:比较两个概率分布的接近程度.在统计应用中,我们经常需要用一个简单的,近似的概率分布 f * ...

  5. PRML读书会第十章 Approximate Inference(近似推断,变分推断,KL散度,平均场, Mean Field )

    主讲人 戴玮 (新浪微博: @戴玮_CASIA) Wilbur_中博(1954123) 20:02:04 我们在前面看到,概率推断的核心任务就是计算某分布下的某个函数的期望.或者计算边缘概率分布.条件 ...

  6. 非负矩阵分解(1):准则函数及KL散度

    作者:桂. 时间:2017-04-06  12:29:26 链接:http://www.cnblogs.com/xingshansi/p/6672908.html 声明:欢迎被转载,不过记得注明出处哦 ...

  7. 深度学习中交叉熵和KL散度和最大似然估计之间的关系

    机器学习的面试题中经常会被问到交叉熵(cross entropy)和最大似然估计(MLE)或者KL散度有什么关系,查了一些资料发现优化这3个东西其实是等价的. 熵和交叉熵 提到交叉熵就需要了解下信息论 ...

  8. 机器学习:Kullback-Leibler Divergence (KL 散度)

    今天,我们介绍机器学习里非常常用的一个概念,KL 散度,这是一个用来衡量两个概率分布的相似性的一个度量指标.我们知道,现实世界里的任何观察都可以看成表示成信息和数据,一般来说,我们无法获取数据的总体, ...

  9. 相对熵(KL散度)

    https://blog.csdn.net/weixinhum/article/details/85064685 上一篇文章我们简单介绍了信息熵的概念,知道了信息熵可以表达数据的信息量大小,是信息处理 ...

  10. ELBO 与 KL散度

    浅谈KL散度 一.第一种理解 相对熵(relative entropy)又称为KL散度(Kullback–Leibler divergence,简称KLD),信息散度(information dive ...

随机推荐

  1. #扫描线,线段树#nssl 1459 空间复杂度

    分析 由于\(k\leq 10\)所以考虑用总方案减去经过两个差的绝对值\(\leq k\)的点的路径数 分类讨论一下发现要处理祖先关系和其它关系两种情况,考虑怎么去重,可以将这些答案看作一个个矩形, ...

  2. #前缀和,后缀和#洛谷 4280 [AHOI2008]逆序对

    题目传送门 分析 首先填的数字单调不降,感性理解 那可以维护\([a_1\sim a_{i-1}]\)的\(cnt\)后缀和以及 \([a_{i+1}\sim a_n]\)的\(cnt\)前缀和,那可 ...

  3. #平衡树#洛谷 1110 [ZJOI2007]报表统计

    题目 分析 最小值只需要开两棵平衡树,一棵维护所有元素,一棵维护相邻最小值, 对于全局最小值,对于每次插入查找前驱后继更新最小值即可, 相邻最小值,对于每个原数列的数维护它的开头和结尾是什么数, 然后 ...

  4. #模拟#洛谷 2327 [SCOI2005]扫雷

    题目 分析 考虑最多只有两种情况,因为确定一个位置其它位置随即也能确定, 那么指定第一个位置有没有雷然后判断一下后面推出的雷数是否为0或1,不是显然不行 代码 #include <cstdio& ...

  5. 前端常用库 CDN

    jQuery 链接: v1.9.1:https://i.mazey.net/lib/jquery/1.9.1/jquery.min.js v2.1.1:https://i.mazey.net/lib/ ...

  6. java中的类型擦除type erasure

    目录 简介 举个例子 原因 解决办法 总结 简介 泛型是java从JDK 5开始引入的新特性,泛型的引入可以让我们在代码编译的时候就强制检查传入的类型,从而提升了程序的健壮度. 泛型可以用在类和接口上 ...

  7. 一键部署openGauss2.0.1 CentOS 7.6

    一键部署 openGauss2.0.1[CentOS 7.6] 本文档目的是为了帮助高校学生提供基于 CentOS7.6 操作系统,实现 openGauss 数据库一键式安装的脚本. 该脚本执行成功后 ...

  8. Sample上新,从API 8开始支持!速来拿走

    原文:https://mp.weixin.qq.com/s/TxUOSXySZRwQaECenxt-Og ,点击链接查看更多技术内容. 搭载API 8的新SDK已经发布.围绕着新SDK,官方贴心地输出 ...

  9. Avalonia 中的样式和控件主题

    在 Avalonia 中,样式是定义控件外观的一种方式,而控件主题则是一组样式和资源,用于定义应用程序的整体外观和感觉.本文将深入探讨这些概念,并提供示例代码以帮助您更好地理解它们. 样式是什么? 样 ...

  10. 力扣25(java&python)-K 个一组翻转链表(困难)

    题目: 给你链表的头节点 head ,每 k 个节点一组进行翻转,请你返回修改后的链表. k 是一个正整数,它的值小于或等于链表的长度.如果节点总数不是 k 的整数倍,那么请将最后剩余的节点保持原有顺 ...