强化学习笔记之【SAC算法】


前言:

本文为强化学习笔记第四篇,第一篇讲的是Q-learning和DQN,第二篇DDPG,第三篇TD3

TD3比DDPG少了一个target_actor网络,其它地方有点小改动

CSDN主页:https://blog.csdn.net/rvdgdsva

博客园主页:https://www.cnblogs.com/hassle


STAND ALONE COMPLEX = S . A . C

首先,我们需要明确,Q-learning算法发展成DQN算法,DQN算法发展成为DDPG算法,而DDPG算法发展成TD3算法,TD3算法发展成SAC算法

Soft Actor-Critic (SAC) 是一种基于策略梯度的深度强化学习算法,它具有最大化奖励与最大化熵(探索性)的双重目标。SAC 通过引入熵正则项,使策略在决策时具有更大的随机性,从而提高探索能力。

一、SAC算法

OK,先用伪代码让你们感受一下SAC算法

# 定义 SAC 超参数
alpha = 0.2 # 熵正则项系数
gamma = 0.99 # 折扣因子
tau = 0.005 # 目标网络软更新参数
lr = 3e-4 # 学习率 # 初始化 Actor、Critic、Target Critic 网络和优化器
actor = ActorNetwork() # 策略网络 π(s)
critic1 = CriticNetwork() # 第一个 Q 网络 Q1(s, a)
critic2 = CriticNetwork() # 第二个 Q 网络 Q2(s, a)
target_critic1 = CriticNetwork() # 目标 Q 网络 1
target_critic2 = CriticNetwork() # 目标 Q 网络 2 # 将目标 Q 网络的参数设置为与 Critic 网络相同
target_critic1.load_state_dict(critic1.state_dict())
target_critic2.load_state_dict(critic2.state_dict()) # 初始化优化器
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=lr)
critic1_optimizer = torch.optim.Adam(critic1.parameters(), lr=lr)
critic2_optimizer = torch.optim.Adam(critic2.parameters(), lr=lr) # 经验回放池(Replay Buffer)
replay_buffer = ReplayBuffer() # SAC 训练循环
for each iteration:
# Step 1: 从 Replay Buffer 中采样一个批次 (state, action, reward, next_state)
batch = replay_buffer.sample()
state, action, reward, next_state, done = batch # Step 2: 计算目标 Q 值 (y)
with torch.no_grad():
# 从 Actor 网络中获取 next_state 的下一个动作
next_action, next_log_prob = actor.sample(next_state) # 目标 Q 值的计算:使用目标 Q 网络的最小值 + 熵项
target_q1_value = target_critic1(next_state, next_action)
target_q2_value = target_critic2(next_state, next_action)
min_target_q_value = torch.min(target_q1_value, target_q2_value) # 目标 Q 值 y = r + γ * (最小目标 Q 值 - α * next_log_prob)
target_q_value = reward + gamma * (1 - done) * (min_target_q_value - alpha * next_log_prob) # Step 3: 更新 Critic 网络
# Critic 1 损失
current_q1_value = critic1(state, action)
critic1_loss = F.mse_loss(current_q1_value, target_q_value) # Critic 2 损失
current_q2_value = critic2(state, action)
critic2_loss = F.mse_loss(current_q2_value, target_q_value) # 反向传播并更新 Critic 网络参数
critic1_optimizer.zero_grad()
critic1_loss.backward()
critic1_optimizer.step() critic2_optimizer.zero_grad()
critic2_loss.backward()
critic2_optimizer.step() # Step 4: 更新 Actor 网络
# 通过 Actor 网络生成新的动作及其 log 概率
new_action, log_prob = actor.sample(state) # 计算 Actor 的目标损失:L = α * log_prob - Q1(s, π(s))
q1_value = critic1(state, new_action)
actor_loss = (alpha * log_prob - q1_value).mean() # 反向传播并更新 Actor 网络参数
actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step() # Step 5: 软更新目标 Q 网络参数
with torch.no_grad():
for param, target_param in zip(critic1.parameters(), target_critic1.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) for param, target_param in zip(critic2.parameters(), target_critic2.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

二、SAC算法Latex解释

1、初始化 Actor、Critic1、Critic2、TargetCritic1 、TargetCritic2 网络

2、Buffer中采样 (state, action, reward, next_state)

3、Actor 输入 next_state 对应输出 next_action 和 next_log_prob

4、Actor 输入 state 对应输出 new_action 和 log_prob

5、Critic1 和 Critic2 分别输入next_state 和 next_action 取其中较小输出经熵正则计算得 target_q_value

6、使用 MSE_loss(Critic1(state, action), target_q_value) 更新 Critic1

7、使用 MSE_loss(Critic2(state, action), target_q_value) 更新 Critic2

8、使用 (alpha * log_prob - critic1(state, new_action)).mean() 更新 Actor


三、SAC五大网络和模块

SAC 算法 中,Actor、Critic1、Critic2、Target Critic1 和 Target Critic2 网络是核心模块,它们分别用于输出动作、评估状态-动作对的价值,并通过目标网络进行稳定的更新。

3.1 Actor 网络

Actor 网络用于在给定状态下输出一个高斯分布的均值和标准差(即策略)。它是通过神经网络近似的随机策略。用于选择动作。

import torch
import torch.nn as nn class ActorNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(ActorNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.mean_layer = nn.Linear(256, action_dim) # 输出动作的均值
self.log_std_layer = nn.Linear(256, action_dim) # 输出动作的log标准差 def forward(self, state):
x = torch.relu(self.fc1(state))
x = torch.relu(self.fc2(x))
mean = self.mean_layer(x) # 输出动作均值
log_std = self.log_std_layer(x) # 输出 log 标准差
log_std = torch.clamp(log_std, min=-20, max=2) # 限制标准差范围
return mean, log_std def sample(self, state):
mean, log_std = self.forward(state)
std = torch.exp(log_std) # 将 log 标准差转为标准差
normal = torch.distributions.Normal(mean, std)
action = normal.rsample() # 通过重参数化技巧进行采样
log_prob = normal.log_prob(action).sum(-1) # 计算 log 概率
return action, log_prob

3.2 Critic1 和 Critic2 网络

Critic 网络用于计算状态-动作对的 Q 值,SAC 使用两个 Critic 网络(Critic1 和 Critic2)来缓解 Q 值的过估计问题。

class CriticNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(CriticNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim + action_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.q_value_layer = nn.Linear(256, 1) # 输出 Q 值 def forward(self, state, action):
x = torch.cat([state, action], dim=-1) # 将 state 和 action 作为输入
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
q_value = self.q_value_layer(x) # 输出 Q 值
return q_value

3.3 Target Critic1 和 Target Critic2 网络

Target Critic 网络的结构与 Critic 网络相同,用于稳定 Q 值更新。它们通过软更新(即在每次训练后慢慢接近 Critic 网络的参数)来保持训练的稳定性。

class TargetCriticNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(TargetCriticNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim + action_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.q_value_layer = nn.Linear(256, 1) # 输出 Q 值 def forward(self, state, action):
x = torch.cat([state, action], dim=-1) # 将 state 和 action 作为输入
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
q_value = self.q_value_layer(x) # 输出 Q 值
return q_value

3.4 软更新模块

在 SAC 中,目标网络会通过软更新逐渐逼近 Critic 网络的参数。每次更新后,目标网络参数会按照 ττ 的比例向 Critic 网络的参数靠拢。

def soft_update(critic, target_critic, tau=0.005):
for param, target_param in zip(critic.parameters(), target_critic.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

3.5 总结

  1. 初始化网络和参数:

    • Actor 网络:用于选择动作。
    • Critic 1 和 Critic 2 网络:用于估计 Q 值。
    • Target Critic 1 和 Target Critic 2:与 Critic 网络架构相同,用于生成更稳定的目标 Q 值。
  2. 目标 Q 值计算:
    • 使用目标网络计算下一状态下的 Q 值。
    • 取两个 Q 网络输出的最小值,防止 Q 值的过估计。
    • 引入熵正则项,计算公式:$$y=r+\gamma\cdot\min(Q_1,Q_2)-\alpha\cdot\log\pi(a|s)$$
  3. 更新 Critic 网络:
    • 最小化目标 Q 值与当前 Q 值的均方误差 (MSE)。
  4. 更新 Actor 网络:
    • 最大化目标损失:$$L=\alpha\cdot\log\pi(a|s)-Q_1(s,\pi(s))$$,即在保证探索的情况下选择高价值动作。
  5. 软更新目标网络:
    • 软更新目标 Q 网络参数,使得目标网络参数缓慢向当前网络靠近,避免振荡。

强化学习笔记之【SAC算法】的更多相关文章

  1. 强化学习-学习笔记7 | Sarsa算法原理与推导

    Sarsa算法 是 TD算法的一种,之前没有严谨推导过 TD 算法,这一篇就来从数学的角度推导一下 Sarsa 算法.注意,这部分属于 TD算法的延申. 7. Sarsa算法 7.1 推导 TD ta ...

  2. GMM高斯混合模型学习笔记(EM算法求解)

    提出混合模型主要是为了能更好地近似一些较复杂的样本分布,通过不断添加component个数,能够随意地逼近不论什么连续的概率分布.所以我们觉得不论什么样本分布都能够用混合模型来建模.由于高斯函数具有一 ...

  3. CS229 - MachineLearning - 12 强化学习笔记

    Ng的机器学习课,课程资源:cs229-课件    网易公开课-视频 问题数学模型: 马尔科夫过程五元组{S.a.Psa.γ.R},分别对应 {状态.行为.状态s下做出a行为的概率.常数.回报}. 一 ...

  4. 【学习笔记】 Adaboost算法

    前言 之前的学习中也有好几次尝试过学习该算法,但是都无功而返,不仅仅是因为该算法各大博主.大牛的描述都比较晦涩难懂,同时我自己学习过程中也心浮气躁,不能专心. 现如今决定一口气肝到底,这样我明天就可以 ...

  5. 挑子学习笔记:DBSCAN算法的python实现

    转载请标明出处:https://www.cnblogs.com/tiaozistudy/p/dbscan_algorithm.html DBSCAN(Density-Based Spatial Clu ...

  6. 普通平衡树学习笔记之Splay算法

    前言 今天不容易有一天的自由学习时间,当然要用来"学习".在此记录一下今天学到的最基础的平衡树. 定义 平衡树是二叉搜索树和堆合并构成的数据结构,它是一 棵空树或它的左右两个子树的 ...

  7. 【算法学习笔记】Meissel-Lehmer 算法 (亚线性时间找出素数个数)

    「Meissel-Lehmer 算法」是一种能在亚线性时间复杂度内求出 \(1\sim n\) 内质数个数的一种算法. 在看素数相关论文时发现了这个算法,论文链接:Here. 算法的细节来自 OI w ...

  8. python学习笔记(MD5算法)

    博主最近进度停滞了 对web开发理解欠缺好多内容 今天整理下MD5算法,这个涉及到mysql数据库存储用户表密码字段的时候 一般是带有加密的 # -*- coding: utf-8 -*- impor ...

  9. 【学习笔记】分类算法-k近邻算法

    k-近邻算法采用测量不同特征值之间的距离来进行分类. 优点:精度高.对异常值不敏感.无数据输入假定 缺点:计算复杂度高.空间复杂度高 使用数据范围:数值型和标称型 用例子来理解k-近邻算法 电影可以按 ...

  10. Web安全学习笔记之DES算法实例详解

    转自http://www.hankcs.com/security/des-algorithm-illustrated.html 译自J. Orlin Grabbe的名作<DES Algorith ...

随机推荐

  1. 【Maxwell】03 定向监听&全量输出

    一.定向监听 定向监听,即只监听某一个特定的表,或者库 1.创建样本案例 -- 创建监听的库(演示样本) CREATE DATABASE `test-db-2` CHARACTER SET 'utf8 ...

  2. 【JavaWeb】HttpClient

    需要的依赖: <!-- https://mvnrepository.com/artifact/org.apache.httpcomponents/httpclient --> <de ...

  3. 【SpringBoot】13 数据访问P1 整合Jdbc

    SpringBoot与数据访问概述: 对于数据访问层,无论是SQL还是NOSQL,Spring Boot默认采用整合Spring Data的方式进行统一处理, 添加大量自动配置,屏蔽了很多设置.引入各 ...

  4. 什么是MMU

    一.MMU的定义   MMU是Memory Management Unit的缩写,中文名是内存管理单元,有时也称作分页内存管理单元(Paged Memory Management Unit,缩写为PM ...

  5. springboot代码自动生成

    在项目开始阶段经常需要自动生成一批代码,如果使用了mybatis则可以使用mybatis plus就可以生成mybatis相关代码.不过经常项目中还有一些mvc代码需要生成,比如说前端代码.相关sql ...

  6. 我是如何使用 vue2+element-ui 处理负责表单,避免单文件过大的问题

    引言 在工作中我经常需要处理一些复杂.动态表单,但是随着需求不断迭代,我们也许会发现曾经两三百行的.vue文件现在不知不觉到了两千行,三千行,甚至更多... 这对于一个需要长期维护的项目,无疑是增加了 ...

  7. 【Jenkins】之自动化测试持续集成

    一.创建jenkins项目 选择节点 创建指定名称的目录名: 写命令,执行shell: 命令填写: # 引入电脑配置文件 #. ~/.bash_profile cd Python_Interface ...

  8. 【Jmeter】之批量处理多接口压力测试

    一.需求前提 1.有以下三个步骤: ①创建单据 ②审核单据 ③确认单据 让三个相关接口进行一连串批量请求操作,直到所有批量数据确认单据成功. 二.测试计划 需要说明的是,因为每个接口可能处理的不太一样 ...

  9. seata 下载及安装

    分布式事务 参考文章: 分布式事务实战方案汇总 https://www.cnblogs.com/yizhiamumu/p/16625677.html 分布式事务原理及解决方案案例https://www ...

  10. pycharm批量注释

    pycharm批量注释不像是spyder可以鼠标右键选择,pycharm是要用快捷键的,选中要注释的代码,然后快捷键就可以了. 注释代码和取消注释代码的快捷键都一样ctrl + /