强化学习笔记之【SAC算法】


前言:

本文为强化学习笔记第四篇,第一篇讲的是Q-learning和DQN,第二篇DDPG,第三篇TD3

TD3比DDPG少了一个target_actor网络,其它地方有点小改动

CSDN主页:https://blog.csdn.net/rvdgdsva

博客园主页:https://www.cnblogs.com/hassle


STAND ALONE COMPLEX = S . A . C

首先,我们需要明确,Q-learning算法发展成DQN算法,DQN算法发展成为DDPG算法,而DDPG算法发展成TD3算法,TD3算法发展成SAC算法

Soft Actor-Critic (SAC) 是一种基于策略梯度的深度强化学习算法,它具有最大化奖励与最大化熵(探索性)的双重目标。SAC 通过引入熵正则项,使策略在决策时具有更大的随机性,从而提高探索能力。

一、SAC算法

OK,先用伪代码让你们感受一下SAC算法

# 定义 SAC 超参数
alpha = 0.2 # 熵正则项系数
gamma = 0.99 # 折扣因子
tau = 0.005 # 目标网络软更新参数
lr = 3e-4 # 学习率 # 初始化 Actor、Critic、Target Critic 网络和优化器
actor = ActorNetwork() # 策略网络 π(s)
critic1 = CriticNetwork() # 第一个 Q 网络 Q1(s, a)
critic2 = CriticNetwork() # 第二个 Q 网络 Q2(s, a)
target_critic1 = CriticNetwork() # 目标 Q 网络 1
target_critic2 = CriticNetwork() # 目标 Q 网络 2 # 将目标 Q 网络的参数设置为与 Critic 网络相同
target_critic1.load_state_dict(critic1.state_dict())
target_critic2.load_state_dict(critic2.state_dict()) # 初始化优化器
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=lr)
critic1_optimizer = torch.optim.Adam(critic1.parameters(), lr=lr)
critic2_optimizer = torch.optim.Adam(critic2.parameters(), lr=lr) # 经验回放池(Replay Buffer)
replay_buffer = ReplayBuffer() # SAC 训练循环
for each iteration:
# Step 1: 从 Replay Buffer 中采样一个批次 (state, action, reward, next_state)
batch = replay_buffer.sample()
state, action, reward, next_state, done = batch # Step 2: 计算目标 Q 值 (y)
with torch.no_grad():
# 从 Actor 网络中获取 next_state 的下一个动作
next_action, next_log_prob = actor.sample(next_state) # 目标 Q 值的计算:使用目标 Q 网络的最小值 + 熵项
target_q1_value = target_critic1(next_state, next_action)
target_q2_value = target_critic2(next_state, next_action)
min_target_q_value = torch.min(target_q1_value, target_q2_value) # 目标 Q 值 y = r + γ * (最小目标 Q 值 - α * next_log_prob)
target_q_value = reward + gamma * (1 - done) * (min_target_q_value - alpha * next_log_prob) # Step 3: 更新 Critic 网络
# Critic 1 损失
current_q1_value = critic1(state, action)
critic1_loss = F.mse_loss(current_q1_value, target_q_value) # Critic 2 损失
current_q2_value = critic2(state, action)
critic2_loss = F.mse_loss(current_q2_value, target_q_value) # 反向传播并更新 Critic 网络参数
critic1_optimizer.zero_grad()
critic1_loss.backward()
critic1_optimizer.step() critic2_optimizer.zero_grad()
critic2_loss.backward()
critic2_optimizer.step() # Step 4: 更新 Actor 网络
# 通过 Actor 网络生成新的动作及其 log 概率
new_action, log_prob = actor.sample(state) # 计算 Actor 的目标损失:L = α * log_prob - Q1(s, π(s))
q1_value = critic1(state, new_action)
actor_loss = (alpha * log_prob - q1_value).mean() # 反向传播并更新 Actor 网络参数
actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step() # Step 5: 软更新目标 Q 网络参数
with torch.no_grad():
for param, target_param in zip(critic1.parameters(), target_critic1.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) for param, target_param in zip(critic2.parameters(), target_critic2.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

二、SAC算法Latex解释

1、初始化 Actor、Critic1、Critic2、TargetCritic1 、TargetCritic2 网络

2、Buffer中采样 (state, action, reward, next_state)

3、Actor 输入 next_state 对应输出 next_action 和 next_log_prob

4、Actor 输入 state 对应输出 new_action 和 log_prob

5、Critic1 和 Critic2 分别输入next_state 和 next_action 取其中较小输出经熵正则计算得 target_q_value

6、使用 MSE_loss(Critic1(state, action), target_q_value) 更新 Critic1

7、使用 MSE_loss(Critic2(state, action), target_q_value) 更新 Critic2

8、使用 (alpha * log_prob - critic1(state, new_action)).mean() 更新 Actor


三、SAC五大网络和模块

SAC 算法 中,Actor、Critic1、Critic2、Target Critic1 和 Target Critic2 网络是核心模块,它们分别用于输出动作、评估状态-动作对的价值,并通过目标网络进行稳定的更新。

3.1 Actor 网络

Actor 网络用于在给定状态下输出一个高斯分布的均值和标准差(即策略)。它是通过神经网络近似的随机策略。用于选择动作。

import torch
import torch.nn as nn class ActorNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(ActorNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.mean_layer = nn.Linear(256, action_dim) # 输出动作的均值
self.log_std_layer = nn.Linear(256, action_dim) # 输出动作的log标准差 def forward(self, state):
x = torch.relu(self.fc1(state))
x = torch.relu(self.fc2(x))
mean = self.mean_layer(x) # 输出动作均值
log_std = self.log_std_layer(x) # 输出 log 标准差
log_std = torch.clamp(log_std, min=-20, max=2) # 限制标准差范围
return mean, log_std def sample(self, state):
mean, log_std = self.forward(state)
std = torch.exp(log_std) # 将 log 标准差转为标准差
normal = torch.distributions.Normal(mean, std)
action = normal.rsample() # 通过重参数化技巧进行采样
log_prob = normal.log_prob(action).sum(-1) # 计算 log 概率
return action, log_prob

3.2 Critic1 和 Critic2 网络

Critic 网络用于计算状态-动作对的 Q 值,SAC 使用两个 Critic 网络(Critic1 和 Critic2)来缓解 Q 值的过估计问题。

class CriticNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(CriticNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim + action_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.q_value_layer = nn.Linear(256, 1) # 输出 Q 值 def forward(self, state, action):
x = torch.cat([state, action], dim=-1) # 将 state 和 action 作为输入
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
q_value = self.q_value_layer(x) # 输出 Q 值
return q_value

3.3 Target Critic1 和 Target Critic2 网络

Target Critic 网络的结构与 Critic 网络相同,用于稳定 Q 值更新。它们通过软更新(即在每次训练后慢慢接近 Critic 网络的参数)来保持训练的稳定性。

class TargetCriticNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(TargetCriticNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim + action_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.q_value_layer = nn.Linear(256, 1) # 输出 Q 值 def forward(self, state, action):
x = torch.cat([state, action], dim=-1) # 将 state 和 action 作为输入
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
q_value = self.q_value_layer(x) # 输出 Q 值
return q_value

3.4 软更新模块

在 SAC 中,目标网络会通过软更新逐渐逼近 Critic 网络的参数。每次更新后,目标网络参数会按照 ττ 的比例向 Critic 网络的参数靠拢。

def soft_update(critic, target_critic, tau=0.005):
for param, target_param in zip(critic.parameters(), target_critic.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

3.5 总结

  1. 初始化网络和参数:

    • Actor 网络:用于选择动作。
    • Critic 1 和 Critic 2 网络:用于估计 Q 值。
    • Target Critic 1 和 Target Critic 2:与 Critic 网络架构相同,用于生成更稳定的目标 Q 值。
  2. 目标 Q 值计算:
    • 使用目标网络计算下一状态下的 Q 值。
    • 取两个 Q 网络输出的最小值,防止 Q 值的过估计。
    • 引入熵正则项,计算公式:$$y=r+\gamma\cdot\min(Q_1,Q_2)-\alpha\cdot\log\pi(a|s)$$
  3. 更新 Critic 网络:
    • 最小化目标 Q 值与当前 Q 值的均方误差 (MSE)。
  4. 更新 Actor 网络:
    • 最大化目标损失:$$L=\alpha\cdot\log\pi(a|s)-Q_1(s,\pi(s))$$,即在保证探索的情况下选择高价值动作。
  5. 软更新目标网络:
    • 软更新目标 Q 网络参数,使得目标网络参数缓慢向当前网络靠近,避免振荡。

强化学习笔记之【SAC算法】的更多相关文章

  1. 强化学习-学习笔记7 | Sarsa算法原理与推导

    Sarsa算法 是 TD算法的一种,之前没有严谨推导过 TD 算法,这一篇就来从数学的角度推导一下 Sarsa 算法.注意,这部分属于 TD算法的延申. 7. Sarsa算法 7.1 推导 TD ta ...

  2. GMM高斯混合模型学习笔记(EM算法求解)

    提出混合模型主要是为了能更好地近似一些较复杂的样本分布,通过不断添加component个数,能够随意地逼近不论什么连续的概率分布.所以我们觉得不论什么样本分布都能够用混合模型来建模.由于高斯函数具有一 ...

  3. CS229 - MachineLearning - 12 强化学习笔记

    Ng的机器学习课,课程资源:cs229-课件    网易公开课-视频 问题数学模型: 马尔科夫过程五元组{S.a.Psa.γ.R},分别对应 {状态.行为.状态s下做出a行为的概率.常数.回报}. 一 ...

  4. 【学习笔记】 Adaboost算法

    前言 之前的学习中也有好几次尝试过学习该算法,但是都无功而返,不仅仅是因为该算法各大博主.大牛的描述都比较晦涩难懂,同时我自己学习过程中也心浮气躁,不能专心. 现如今决定一口气肝到底,这样我明天就可以 ...

  5. 挑子学习笔记:DBSCAN算法的python实现

    转载请标明出处:https://www.cnblogs.com/tiaozistudy/p/dbscan_algorithm.html DBSCAN(Density-Based Spatial Clu ...

  6. 普通平衡树学习笔记之Splay算法

    前言 今天不容易有一天的自由学习时间,当然要用来"学习".在此记录一下今天学到的最基础的平衡树. 定义 平衡树是二叉搜索树和堆合并构成的数据结构,它是一 棵空树或它的左右两个子树的 ...

  7. 【算法学习笔记】Meissel-Lehmer 算法 (亚线性时间找出素数个数)

    「Meissel-Lehmer 算法」是一种能在亚线性时间复杂度内求出 \(1\sim n\) 内质数个数的一种算法. 在看素数相关论文时发现了这个算法,论文链接:Here. 算法的细节来自 OI w ...

  8. python学习笔记(MD5算法)

    博主最近进度停滞了 对web开发理解欠缺好多内容 今天整理下MD5算法,这个涉及到mysql数据库存储用户表密码字段的时候 一般是带有加密的 # -*- coding: utf-8 -*- impor ...

  9. 【学习笔记】分类算法-k近邻算法

    k-近邻算法采用测量不同特征值之间的距离来进行分类. 优点:精度高.对异常值不敏感.无数据输入假定 缺点:计算复杂度高.空间复杂度高 使用数据范围:数值型和标称型 用例子来理解k-近邻算法 电影可以按 ...

  10. Web安全学习笔记之DES算法实例详解

    转自http://www.hankcs.com/security/des-algorithm-illustrated.html 译自J. Orlin Grabbe的名作<DES Algorith ...

随机推荐

  1. 【MongoDB】Re03 索引

    MongoDB的索引种类 单属性索引 MongoDB支持在文档的单个字段上创建用户定义的升序/降序索引,称为单字段索引(Single Field Index). 对于单个字段索引和排序操作,索引键的排 ...

  2. 全地形人形机器人(humanoid)是否只能进行短距视野感知呢 —— 实时地形感知

    相关: https://capital.lenovo.com/news/detail/id/924/s/1.html 常见的人形机器人都是测试其手臂灵活度为主,但是近日看到一款以全地形步态行走为主的机 ...

  3. baselines库中atari_wrappers.py中的环境包装器的顺序问题

    如题: 在baselines中对atari游戏环境进行包装的代码在atari_wrappers.py模块中, def make_atari(env_id, max_episode_steps=None ...

  4. golang 指定权限是 0o755 而不是 0755

    在Go语言中,当指定文件权限时,使用前缀 0o 来明确表示八进制数是一种推荐的做法. 这是因为在Go语言中,八进制字面量必须以 0o 或 0O 开头,后跟八进制数字(0-7). 这种语法是从 Go 1 ...

  5. LVGL line组件

    目录 一.Line(线条)的概念 二.线条组件的使用 1.创建线条对象 2.设置点数组 3.确定y轴的方向(可选) 4.设置线条风格(可选) 4.1创建风格 4.2设置风格 5.将创建好的线段组件添加 ...

  6. SpringBoot整合RabbitMQ 通俗易懂 超详细 【内含案例】

    SpringBoot结合RabbitMq SpringBoot 框架部署 HelloWorld 简单模式 Topic 通配符模式 一.SpringBoot 框架部署 1.创建Maven工程(我用的ID ...

  7. SMU Autumn 2023 Round 1(Div.1)

    SMU Autumn 2023 Round 1(Div.1) A. Set or Decrease(枚举) 题意就是你可以进行两种操作,将\(a_i-1\)或者令\(a_i\)等于\(a_j\),然后 ...

  8. curl可以访问虚拟机资源,但是宿主机浏览器不能访问

    如果想从宿主机访问到虚拟机内的php,需要关闭宿主机的代理,并且设置虚拟机内的防火墙不要屏蔽宿主机的ip. 设置虚拟机防火墙方法: 查找宿主机IP:win+r,输入ipconfig 打开虚拟机,输入s ...

  9. Linux命令cURL详解,并实现文件定时上传到ftp服务器的程序

    前言 前段时间群里讨论,想实现某个文件定时上传到服务器要怎么来实现.我记得之前做过 一个项目:为高通的iot模组编写FOTA功能:实现模组可以远程下载升级镜像包,实现版本升级功能.并当时使用的一个超级 ...

  10. Win32 动态库dll

    这两天学习动态库的练习,分享下方法 实例.封装窗口类的两种状态. 1.自定义窗口类QWnd 2.资源模板窗口对话框类 下面是dll的头文件,类的声明 #pragma once #ifndef _CLA ...