一文教你在MindSpore中实现A2C算法训练
本文分享自华为云社区《MindSpore A2C 强化学习》,作者:irrational。
Advantage Actor-Critic (A2C)算法是一个强化学习算法,它结合了策略梯度(Actor)和价值函数(Critic)的方法。A2C算法在许多强化学习任务中表现优越,因为它能够利用价值函数来减少策略梯度的方差,同时直接优化策略。
A2C算法的核心思想
- Actor:根据当前策略选择动作。
- Critic:评估一个状态-动作对的值(通常是使用状态值函数或动作值函数)。
- 优势函数(Advantage Function):用来衡量某个动作相对于平均水平的好坏,通常定义为A(s,a)=Q(s,a)−V(s)。
A2C算法的伪代码
以下是A2C算法的伪代码:
Initialize policy network (actor) π with parameters θ
Initialize value network (critic) V with parameters w
Initialize learning rates α_θ for policy network and α_w for value network for each episode do
Initialize state s
while state s is not terminal do
# Actor: select action a according to the current policy π(a|s; θ)
a = select_action(s, θ) # Execute action a in the environment, observe reward r and next state s'
r, s' = environment.step(a) # Critic: compute the value of the current state V(s; w)
V_s = V(s, w) # Critic: compute the value of the next state V(s'; w)
V_s_prime = V(s', w) # Compute the TD error (δ)
δ = r + γ * V_s_prime - V_s # Critic: update the value network parameters w
w = w + α_w * δ * ∇_w V(s; w) # Compute the advantage function A(s, a)
A = δ # Actor: update the policy network parameters θ
θ = θ + α_θ * A * ∇_θ log π(a|s; θ) # Move to the next state
s = s'
end while
end for
解释
- 初始化:初始化策略网络(Actor)和价值网络(Critic)的参数,以及它们的学习率。
- 循环每个Episode:在每个Episode开始时,初始化状态。
- 选择动作:根据当前策略从Actor中选择动作。
- 执行动作:在环境中执行动作,并观察奖励和下一个状态。
- 计算状态值:用Critic评估当前状态和下一个状态的值。
- 计算TD误差:计算时序差分误差(Temporal Difference Error),它是当前奖励加上下一个状态的折扣值与当前状态值的差。
- 更新Critic:根据TD误差更新价值网络的参数。
- 计算优势函数:使用TD误差计算优势函数。
- 更新Actor:根据优势函数更新策略网络的参数。
- 更新状态:移动到下一个状态,重复上述步骤,直到Episode结束。
这个伪代码展示了A2C算法的核心步骤,实际实现中可能会有更多细节,如使用折扣因子γ、多个并行环境等。
代码如下:
import argparse from mindspore import context
from mindspore import dtype as mstype
from mindspore.communication import init from mindspore_rl.algorithm.a2c import config
from mindspore_rl.algorithm.a2c.a2c_session import A2CSession
from mindspore_rl.algorithm.a2c.a2c_trainer import A2CTrainer parser = argparse.ArgumentParser(description="MindSpore Reinforcement A2C")
parser.add_argument("--episode", type=int, default=10000, help="total episode numbers.")
parser.add_argument(
"--device_target",
type=str,
default="CPU",
choices=["CPU", "GPU", "Ascend", "Auto"],
help="Choose a devioptions.device_targece to run the ac example(Default: Auto).",
)
parser.add_argument(
"--precision_mode",
type=str,
default="fp32",
choices=["fp32", "fp16"],
help="Precision mode",
)
parser.add_argument(
"--env_yaml",
type=str,
default="../env_yaml/CartPole-v0.yaml",
help="Choose an environment yaml to update the a2c example(Default: CartPole-v0.yaml).",
)
parser.add_argument(
"--algo_yaml",
type=str,
default=None,
help="Choose an algo yaml to update the a2c example(Default: None).",
)
parser.add_argument(
"--enable_distribute",
type=bool,
default=False,
help="Train in distribute mode (Default: False).",
)
parser.add_argument(
"--worker_num",
type=int,
default=2,
help="Worker num (Default: 2).",
)
options, _ = parser.parse_known_args()
首先初始化参数,然后我这里用cpu运行:options.device_targe = “CPU”
episode=options.episode
"""Train a2c"""
if options.device_target != "Auto":
context.set_context(device_target=options.device_target)
if context.get_context("device_target") in ["CPU", "GPU"]:
context.set_context(enable_graph_kernel=True)
context.set_context(mode=context.GRAPH_MODE)
compute_type = (
mstype.float32 if options.precision_mode == "fp32" else mstype.float16
)
config.algorithm_config["policy_and_network"]["params"][
"compute_type"
] = compute_type
if compute_type == mstype.float16 and options.device_target != "Ascend":
raise ValueError("Fp16 mode is supported by Ascend backend.")
is_distribte = options.enable_distribute
if is_distribte:
init()
context.set_context(enable_graph_kernel=False)
config.deploy_config["worker_num"] = options.worker_num
a2c_session = A2CSession(options.env_yaml, options.algo_yaml, is_distribte)
设置上下文管理器
import sys
import time
from io import StringIO class RealTimeCaptureAndDisplayOutput(object):
def __init__(self):
self._original_stdout = sys.stdout
self._original_stderr = sys.stderr
self.captured_output = StringIO() def write(self, text):
self._original_stdout.write(text) # 实时打印
self.captured_output.write(text) # 保存到缓冲区 def flush(self):
self._original_stdout.flush()
self.captured_output.flush() def __enter__(self):
sys.stdout = self
sys.stderr = self
return self def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout = self._original_stdout
sys.stderr = self._original_stderr
episode=10
# dqn_session.run(class_type=DQNTrainer, episode=episode)
with RealTimeCaptureAndDisplayOutput() as captured_new:
a2c_session.run(class_type=A2CTrainer, episode=episode)
import re
import matplotlib.pyplot as plt # 原始输出
raw_output = captured_new.captured_output.getvalue() # 使用正则表达式从输出中提取loss和rewards
loss_pattern = r"loss=(\d+\.\d+)"
reward_pattern = r"running_reward=(\d+\.\d+)"
loss_values = [float(match.group(1)) for match in re.finditer(loss_pattern, raw_output)]
reward_values = [float(match.group(1)) for match in re.finditer(reward_pattern, raw_output)] # 绘制loss曲线
plt.plot(loss_values, label='Loss')
plt.xlabel('Episode')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.show() # 绘制reward曲线
plt.plot(reward_values, label='Rewards')
plt.xlabel('Episode')
plt.ylabel('Rewards')
plt.title('Rewards Curve')
plt.legend()
plt.show()
展示结果:

下面我将详细解释你提供的 MindSpore A2C 算法训练配置参数的含义:
Actor 配置
'actor': {
'number': 1,
'type': mindspore_rl.algorithm.a2c.a2c.A2CActor,
'params': {
'collect_environment': PyFuncWrapper<
(_envs): GymEnvironment<>
>,
'eval_environment': PyFuncWrapper<
(_envs): GymEnvironment<>
>,
'replay_buffer': None,
'a2c_net': ActorCriticNet<
(common): Dense<input_channels=4, output_channels=128, has_bias=True>
(actor): Dense<input_channels=128, output_channels=2, has_bias=True>
(critic): Dense<input_channels=128, output_channels=1, has_bias=True>
(relu): LeakyReLU<>
>},
'policies': [],
'networks': ['a2c_net']
}
number: Actor 的实例数量,这里设置为1,表示使用一个 Actor 实例。type: Actor 的类型,这里使用mindspore_rl.algorithm.a2c.a2c.A2CActor。params: Actor 的参数配置。collect_environment和eval_environment: 使用PyFuncWrapper包装的GymEnvironment,用于数据收集和评估环境。replay_buffer: 设置为None,表示不使用经验回放缓冲区。a2c_net: Actor-Critic 网络,包含一个公共层、一个 Actor 层和一个 Critic 层,以及一个 Leaky ReLU 激活函数。
policies和networks: Actor 关联的策略和网络,这里主要是a2c_net。
Learner 配置
'learner': {
'number': 1,
'type': mindspore_rl.algorithm.a2c.a2c.A2CLearner,
'params': {
'gamma': 0.99,
'state_space_dim': 4,
'action_space_dim': 2,
'a2c_net': ActorCriticNet<
(common): Dense<input_channels=4, output_channels=128, has_bias=True>
(actor): Dense<input_channels=128, output_channels=2, has_bias=True>
(critic): Dense<input_channels=128, output_channels=1, has_bias=True>
(relu): LeakyReLU<>
>,
'a2c_net_train': TrainOneStepCell<
(network): Loss<
(a2c_net): ActorCriticNet<
(common): Dense<input_channels=4, output_channels=128, has_bias=True>
(actor): Dense<input_channels=128, output_channels=2, has_bias=True>
(critic): Dense<input_channels=128, output_channels=1, has_bias=True>
(relu): LeakyReLU<>
>
(smoothl1_loss): SmoothL1Loss<>
>
(optimizer): Adam<>
(grad_reducer): Identity<>
>
},
'networks': ['a2c_net_train', 'a2c_net']
}
number: Learner 的实例数量,这里设置为1,表示使用一个 Learner 实例。type: Learner 的类型,这里使用mindspore_rl.algorithm.a2c.a2c.A2CLearner。params: Learner 的参数配置。gamma: 折扣因子,用于未来奖励的折扣计算。state_space_dim: 状态空间的维度,这里为4。action_space_dim: 动作空间的维度,这里为2。a2c_net: Actor-Critic 网络定义,与 Actor 中相同。a2c_net_train: 用于训练的网络,包含损失函数(SmoothL1Loss)、优化器(Adam)和梯度缩减器(Identity)。
networks: Learner 关联的网络,包括a2c_net_train和a2c_net。
Policy and Network 配置
'policy_and_network': {
'type': mindspore_rl.algorithm.a2c.a2c.A2CPolicyAndNetwork,
'params': {
'lr': 0.01,
'state_space_dim': 4,
'action_space_dim': 2,
'hidden_size': 128,
'gamma': 0.99,
'compute_type': mindspore.float32,
'environment_config': {
'id': 'CartPole-v0',
'entry_point': 'gym.envs.classic_control:CartPoleEnv',
'reward_threshold': 195.0,
'nondeterministic': False,
'max_episode_steps': 200,
'_kwargs': {},
'_env_name': 'CartPole'
}
}
}
type: 策略和网络的类型,这里使用mindspore_rl.algorithm.a2c.a2c.A2CPolicyAndNetwork。params: 策略和网络的参数配置。lr: 学习率,这里为0.01。state_space_dim和action_space_dim: 状态和动作空间的维度。hidden_size: 隐藏层的大小,这里为128。gamma: 折扣因子。compute_type: 计算类型,这里为mindspore.float32。environment_config: 环境配置,包括环境 ID、入口、奖励阈值、最大步数等。
Collect Environment 配置
'collect_environment': {
'number': 1,
'type': mindspore_rl.environment.gym_environment.GymEnvironment,
'wrappers': [mindspore_rl.environment.pyfunc_wrapper.PyFuncWrapper],
'params': {
'GymEnvironment': {
'name': 'CartPole-v0',
'seed': 42
},
'name': 'CartPole-v0'
}
}
number: 环境实例数量,这里为1。type: 环境的类型,这里使用mindspore_rl.environment.gym_environment.GymEnvironment。wrappers: 环境使用的包装器,这里是PyFuncWrapper。params: 环境的参数配置,包括环境名称CartPole-v0和随机种子42。
Eval Environment 配置
'eval_environment': {
'number': 1,
'type': mindspore_rl.environment.gym_environment.GymEnvironment,
'wrappers': [mindspore_rl.environment.pyfunc_wrapper.PyFuncWrapper],
'params': {
'GymEnvironment': {
'name': 'CartPole-v0',
'seed': 42
},
'name': 'CartPole-v0'
}
}
- 配置与
collect_environment类似,用于评估模型性能。
总结一下,这些配置定义了 Actor-Critic 算法在 MindSpore 框架中的具体实现,包括 Actor 和 Learner 的设置、策略和网络的参数,以及训练和评估环境的配置。这个还是比较基础的。
一文教你在MindSpore中实现A2C算法训练的更多相关文章
- 带你学习MindSpore中算子使用方法
摘要:本文分享下MindSpore中算子的使用和遇到问题时的解决方法. 本文分享自华为云社区<[MindSpore易点通]算子使用问题与解决方法>,作者:chengxiaoli. 简介 算 ...
- Window10 上MindSpore(CPU)用LeNet网络训练MNIST
本文是在windows10上安装了CPU版本的Mindspore,并在mindspore的master分支基础上使用LeNet网络训练MNIST数据集,实践已训练成功,此文为记录过程中的出现问题: ( ...
- Java中的经典算法之冒泡排序(Bubble Sort)
Java中的经典算法之冒泡排序(Bubble Sort) 神话丿小王子的博客主页 原理:比较两个相邻的元素,将值大的元素交换至右端. 思路:依次比较相邻的两个数,将小数放在前面,大数放在后面.即在第一 ...
- 分布式数据库中的Paxos 算法
分布式数据库中的Paxos 算法 http://baike.baidu.com/link?url=ChmfvtXRZQl7X1VmRU6ypsmZ4b4MbQX1pelw_VenRLnFpq7rMvY ...
- Java中的查找算法之顺序查找(Sequential Search)
Java中的查找算法之顺序查找(Sequential Search) 神话丿小王子的博客主页 a) 原理:顺序查找就是按顺序从头到尾依次往下查找,找到数据,则提前结束查找,找不到便一直查找下去,直到数 ...
- Java中的经典算法之选择排序(SelectionSort)
Java中的经典算法之选择排序(SelectionSort) 神话丿小王子的博客主页 a) 原理:每一趟从待排序的记录中选出最小的元素,顺序放在已排好序的序列最后,直到全部记录排序完毕.也就是:每一趟 ...
- STL中的查找算法
STL中有很多算法,这些算法可以用到一个或多个STL容器(因为STL的一个设计思想是将算法和容器进行分离),也可以用到非容器序列比如数组中.众多算法中,查找算法是应用最为普遍的一类. 单个元素查找 1 ...
- opencv3中的机器学习算法之:EM算法
不同于其它的机器学习模型,EM算法是一种非监督的学习算法,它的输入数据事先不需要进行标注.相反,该算法从给定的样本集中,能计算出高斯混和参数的最大似然估计.也能得到每个样本对应的标注值,类似于kmea ...
- 在opencv3中的机器学习算法
在opencv3.0中,提供了一个ml.cpp的文件,这里面全是机器学习的算法,共提供了这么几种: 1.正态贝叶斯:normal Bayessian classifier 我已在另外一篇博文中介 ...
- Java中的排序算法(2)
Java中的排序算法(2) * 快速排序 * 快速排序使用分治法(Divide and conquer)策略来把一个序列(list)分为两个子序列(sub-lists). * 步骤为: * 1. 从数 ...
随机推荐
- 离线语音识别,vosk,离线流式实时静音噪声监测,支持多语言开发python c++ c# java等
#!/usr/bin/env python3 from vosk import Model, KaldiRecognizer, SetLogLevel import sys import os imp ...
- Linux_aarch64_head.S到main.c的环境建立
PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 环境说明 无 前言 最开始,我仅仅是对linux比较感兴 ...
- 力扣744(java&python)- 寻找比目标字母大的最小字母(简单)
题目: 给你一个排序后的字符列表 letters ,列表中只包含小写英文字母.另给出一个目标字母 target,请你寻找在这一有序列表里比目标字母大的最小字母. 在比较时,字母是依序循环出现的.举个例 ...
- 力扣13(java)-罗马数字转整数(简单)
题目: 罗马数字包含以下七种字符: I, V, X, L,C,D 和 M. 字符 数值I 1V 5X 10L 50C 100D 500M 1000例如, 罗马数字 2 写做 II ,即为两个并列的 1 ...
- 牛客网-SQL专项练习4
①向表evaluate的成绩列添加成绩,从表grade中的成绩一列提取记录,SQL语句为: INSERT INTO evaluate(grade.point) SELECT grade.point ...
- 力扣393(java)-UTF-8编码验证(中等)
题目: 给定一个表示数据的整数数组 data ,返回它是否为有效的 UTF-8 编码. UTF-8 中的一个字符可能的长度为 1 到 4 字节,遵循以下的规则: 对于 1 字节 的字符,字节的第一位设 ...
- [GPT] 用 document.querySelector('.xxx') 选择下级的第二个 div 要怎么写
要选择类名为 .xxx 的元素下的第二个子<div>元素,可以将 querySelectorAll()方法与CSS选择器一起使用. 以下是一个示例: const secondChild ...
- 开发日志:Kylin麒麟操作系统部署ASP.NET CORE
需求场景: 我需要部署的项目是在Windows上开发的,目标框架为.net core 6.0 因此我们需要先在kylin上部署项目运行所需要的环境. 借助百度词条,先看看Kylin是什么: 服务器资源 ...
- vim 使用black 格式化python代码
vim 使用black 格式化代码 github black 的github https://github.com/psf/black 安装 pip3 install black 使用 black f ...
- DNS(7) -- 智能DNS实现
目录 1. 智能DNS 1.1 智能DNS概述 1.2 ACL控制列表 1.3 智能DNS实现 1.3.1 bind-view功能 1.3.2 智能DNS场景实现 1.3.3 生产场景配置示例 1. ...