cartpole游戏,车上顶着一个自由摆动的杆子,实现杆子的平衡,杆子每次倒向一端车就开始移动让杆子保持动态直立的状态,策略函数使用一个两层的简单神经网络,输入状态有4个,车位置,车速度,杆角度,杆速度,输出action为左移动或右移动,输入状态发现至少要给3个才能稳定一会儿,给2个完全学不明白,给4个能学到很稳定的policy

策略梯度实现代码,使用torch实现一个简单的神经网络

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import pygame
import sys
from collections import deque
import numpy as np # 策略网络定义
class PolicyNetwork(nn.Module):
def __init__(self):
super(PolicyNetwork, self).__init__()
self.fc = nn.Sequential(
nn.Linear(4, 10), # 4个状态输入,128个隐藏单元
nn.Tanh(),
nn.Linear(10, 2), # 输出2个动作的概率
nn.Softmax(dim=-1)
) def forward(self, x):
# print(x) 车位置 车速度 杆角度 杆速度
selected_values = x[:, [0,1,2,3]] #只使用车位置和杆角度
return self.fc(selected_values) # 训练函数
def train(policy_net, optimizer, trajectories):
policy_net.zero_grad()
loss = 0
print(trajectories[0])
for trajectory in trajectories: # if trajectory["returns"] > 90:
# returns = torch.tensor(trajectory["returns"]).float()
# else:
returns = torch.tensor(trajectory["returns"]).float() - torch.tensor(trajectory["step_mean_reward"]).float()
# print(f"获得奖励{returns}")
log_probs = trajectory["log_prob"]
loss += -(log_probs * returns).sum() # 计算策略梯度损失
loss.backward()
optimizer.step()
return loss.item() # 主函数
def main():
env = gym.make('CartPole-v1')
policy_net = PolicyNetwork()
optimizer = optim.Adam(policy_net.parameters(), lr=0.01) print(env.action_space)
print(env.observation_space)
pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock() rewards_one_episode= []
for episode in range(10000): state = env.reset()
done = False
trajectories = []
state = state[0]
step = 0
torch.save(policy_net, 'policy_net_full.pth')
while not done:
state_tensor = torch.tensor(state).float().unsqueeze(0)
probs = policy_net(state_tensor)
action = torch.distributions.Categorical(probs).sample().item()
log_prob = torch.log(probs.squeeze(0)[action])
next_state, reward, done, _,_ = env.step(action) # print(episode)
trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
state = next_state for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
sys.exit()
step +=1 # 绘制环境状态
if rewards_one_episode and rewards_one_episode[-1] >99:
screen.fill((255, 255, 255))
cart_x = int(state[0] * 100 + 300)
pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
# print(state)
pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))), 300 - int(50 * torch.cos(torch.tensor(state[2])))), 2)
pygame.display.flip()
clock.tick(200) print(f"第{episode}回合",f"运行{step}步后挂了")
# 为策略梯度计算累积回报
returns = 0 for traj in reversed(trajectories):
returns = traj["reward"] + 0.99 * returns
traj["returns"] = returns
if rewards_one_episode:
# print(rewards_one_episode[:10])
traj["step_mean_reward"] = np.mean(rewards_one_episode[-10:])
else:
traj["step_mean_reward"] = 0
rewards_one_episode.append(returns)
# print(rewards_one_episode[:10])
train(policy_net, optimizer, trajectories) def play(): env = gym.make('CartPole-v1')
policy_net = PolicyNetwork()
pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock() state = env.reset()
done = False
trajectories = deque()
state = state[0]
step = 0
policy_net = torch.load('policy_net_full.pth')
while not done:
state_tensor = torch.tensor(state).float().unsqueeze(0)
probs = policy_net(state_tensor)
action = torch.distributions.Categorical(probs).sample().item()
log_prob = torch.log(probs.squeeze(0)[action])
next_state, reward, done, _,_ = env.step(action) # print(episode)
trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
state = next_state for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
sys.exit() # 绘制环境状态
screen.fill((255, 255, 255))
cart_x = int(state[0] * 100 + 300)
pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
# print(state)
pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))), 300 - int(50 * torch.cos(torch.tensor(state[2])))), 2)
pygame.display.flip()
clock.tick(60)
step +=1 print(f"运行{step}步后挂了") if __name__ == '__main__':
main() #训练
# play() #推理

  运行效果,训练过程不是很稳定,有时候学很多轮次也学不明白,有时侯只需要几十次就可以学明白了

策略梯度玩 cartpole 游戏,强化学习代替PID算法控制平衡杆的更多相关文章

  1. 策略梯度训练cartpole小游戏

    我原来已经安装了anaconda,在此基础上进入cmd进行pip install tensorflow和pip install gym就可以了. 在win10的pycharm做的. policy_gr ...

  2. TensorFlow利用A3C算法训练智能体玩CartPole游戏

    本教程讲解如何使用深度强化学习训练一个可以在 CartPole 游戏中获胜的模型.研究人员使用 tf.keras.OpenAI 训练了一个使用「异步优势动作评价」(Asynchronous Advan ...

  3. DRL 教程 | 如何保持运动小车上的旗杆屹立不倒?TensorFlow利用A3C算法训练智能体玩CartPole游戏

    本教程讲解如何使用深度强化学习训练一个可以在 CartPole 游戏中获胜的模型.研究人员使用 tf.keras.OpenAI 训练了一个使用「异步优势动作评价」(Asynchronous Advan ...

  4. 【强化学习】DQN 算法改进

    DQN 算法改进 (一)Dueling DQN Dueling DQN 是一种基于 DQN 的改进算法.主要突破点:利用模型结构将值函数表示成更加细致的形式,这使得模型能够拥有更好的表现.下面给出公式 ...

  5. 【算法总结】强化学习部分基础算法总结(Q-learning DQN PG AC DDPG TD3)

    总结回顾一下近期学习的RL算法,并给部分实现算法整理了流程图.贴了代码. 1. value-based 基于价值的算法 基于价值算法是通过对agent所属的environment的状态或者状态动作对进 ...

  6. 强化学习(十七) 基于模型的强化学习与Dyna算法框架

    在前面我们讨论了基于价值的强化学习(Value Based RL)和基于策略的强化学习模型(Policy Based RL),本篇我们讨论最后一种强化学习流派,基于模型的强化学习(Model Base ...

  7. 强化学习-时序差分算法(TD)和SARAS法

    1. 前言 我们前面介绍了第一个Model Free的模型蒙特卡洛算法.蒙特卡罗法在估计价值时使用了完整序列的长期回报.而且蒙特卡洛法有较大的方差,模型不是很稳定.本节我们介绍时序差分法,时序差分法不 ...

  8. 强化学习8-时序差分控制离线算法Q-Learning

    Q-Learning和Sarsa一样是基于时序差分的控制算法,那两者有什么区别呢? 这里已经必须引入新的概念 时序差分控制算法的分类:在线和离线 在线控制算法:一直使用一个策略选择动作和更新价值函数, ...

  9. 强化学习读书笔记 - 13 - 策略梯度方法(Policy Gradient Methods)

    强化学习读书笔记 - 13 - 策略梯度方法(Policy Gradient Methods) 学习笔记: Reinforcement Learning: An Introduction, Richa ...

  10. 基于Keras的OpenAI-gym强化学习的车杆/FlappyBird游戏

    强化学习 课程:Q-Learning强化学习(李宏毅).深度强化学习 强化学习是一种允许你创造能从环境中交互学习的AI Agent的机器学习算法,其通过试错来学习.如上图所示,大脑代表AI Agent ...

随机推荐

  1. redis的延迟双删策略

    1,redis数据为什么会存在和数据库数据不一致的问题 在多线程并发情况下,假设有两个数据库修改请求,为保证数据库与redis的数据一致性,修改请求的实现中需要修改数据库后,级联修改redis中的数据 ...

  2. Mysql Order 排序的时候占用很长时间解决思路

    MySQL中的连表查询(JOIN)在进行ORDER BY排序时可能会变得很慢,尤其是当处理大量数据时.以下是一些优化策略,可以帮助减少排序操作的时间: 索引优化: 确保参与排序的列上有索引.如果排序的 ...

  3. 【中秋国庆不断更】OpenHarmony多态样式stateStyles使用场景

    @Styles和@Extend仅仅应用于静态页面的样式复用,stateStyles可以依据组件的内部状态的不同,快速设置不同样式.这就是我们本章要介绍的内容stateStyles(又称为:多态样式). ...

  4. ssm 创建bean的三种方式和spring依赖注入的三种方式

    <!--创建bean的第一种方式:使用默认无参构造函数 在默认情况下: 它会根据默认无参构造函数来创建类对象.如果 bean 中没有默认无参构造函数,将会创建失败--> <bean ...

  5. MySQL的下载、安装和配置

    ​ 一.MySQL的下载 下载地址:http://dev.mysql.com/downloads/mysql 进入下载页面,选择所需版本,这里示范MySQL8.0 图一 选择版本,下载MSI(软件安装 ...

  6. 如何在HarmonyOS对数据库进行备份,恢复与加密

    数据库备份与恢复 场景介绍 当应用在处理一项重要的操作,显然是不能被打断的.例如:写入多个表关联的事务.此时,每个表的写入都是单独的,但是表与表之间的事务关联性不能被分割. 如果操作的过程中出现问题, ...

  7. Linux0.12内核源码解读(2)-Bootsect.S

    大家好,我是呼噜噜,在上一篇文章聊聊x86计算机启动发生的事?我们了解了x86计算机启动过程,MBR.0x7c00是什么?其中当bios引导结束后,操作系统接过计算机的控制权后,发生了哪些事?本文将揭 ...

  8. 树模型--ID3算法

    基于信息增益(Information Gain)的ID3算法 ID3算法的核心是在数据集上应用信息增益准则来进行特征选择,以此递归的构建决策树,以信息熵和信息增益为衡量标准,从而实现对数据的归纳分类. ...

  9. CentOS-6.5快速搭建HTTP服务器和仅供授权用户登陆的FTP服务器

    CentOS-6.5快速搭建HTTP服务器和仅供授权用户登陆的FTP服务器 (2014-01-09 21:29:31) 转载▼ 标签: linux centos 服务器 http vsftp 分类:& ...

  10. 重新整理.net core 计1400篇[七] (.net core 中的依赖注入)

    前言 请阅读第六篇,对于理解.net core 中的依赖注入很关键. 和我们上一篇不同的是,.net core服务注入保存在IServiceCollection 中,而将集合创建的依赖注入容器体现为I ...