Rust实现线段树和懒标记
参考各家代码,用Rust实现了线段树和懒标记。
由于使用了泛型,很多操作都要用闭包自定义实现。
看代码。
// 线段树定义
pub struct SegmentTree<T: Clone>
{
pub data: Vec<T>,
tree: Vec<Option<T>>,
marker: Vec<T>, //懒标记。
query_op: Box<dyn Fn(T, T) -> T>, //查询时,对所有查询元素做的操作。比如加法,就是求区间的所有元素的和。
marker_marker_op: Box<dyn Fn(T, T) -> T>, //marker加到marker上时,对marker的操作。通常我们要marker[i] += p; 来更新标记,但是泛型实现不了,并且考虑到有些用户有别的需求,所以用闭包包装。
marker_t_op: Box<dyn Fn(T, T) -> T>, //marker应用到T时,对T的操作。考虑到有些用户有别的需求,所以用闭包包装。
marker_mul_usize: Box<dyn Fn(T, usize) -> T>, //marker乘usize的方法。这个没法通过要求满足Mul trait自动实现。由于使用了泛型,连乘法都要交给闭包实现。。。
}
impl<T: Clone + Default + Copy + PartialEq> SegmentTree<T> {
pub fn new(
data: Vec<T>,
query_op: Box<dyn Fn(T, T) -> T>,
marker_marker_op: Box<dyn Fn(T, T) -> T>,
marker_t_op: Box<dyn Fn(T, T) -> T>,
marker_mul_usize: Box<dyn Fn(T, usize) -> T>,
) -> Self {
let data_len = data.len();
let mut tr = Self {
data,
marker: vec![T::default(); 4 * data_len], //四倍原数据大小
tree: vec![None; 4 * data_len], //四倍原数据大小
query_op,
marker_marker_op,
marker_t_op,
marker_mul_usize,
};
tr.build();
tr
}
#[inline]
pub fn get(&self, index: usize) -> Option<&T> {
self.data.get(index)
}
#[inline]
pub fn len(&self) -> usize {
self.data.len()
}
#[inline]
fn left_child(index: usize) -> usize {
2 * index + 1
}
#[inline]
fn right_child(index: usize) -> usize {
2 * index + 2
}
#[inline]
fn build(&mut self) {
self.build_segment_tree(0, 0, self.data.len() - 1);
}
// 递归Build
fn build_segment_tree(&mut self, tree_index: usize, left: usize, right: usize) {
if left == right {
self.tree[tree_index] = Some(self.data[left]);
return;
}
let left_tree_index = Self::left_child(tree_index);
let right_tree_index = Self::right_child(tree_index);
let mid = (right - left) / 2 + left;
self.build_segment_tree(left_tree_index, left, mid);
self.build_segment_tree(right_tree_index, mid + 1, right);
// 左右子树数据处理方式
if let Some(l) = self.tree[left_tree_index] {
if let Some(r) = self.tree[right_tree_index] {
self.tree[tree_index] = Some((self.query_op)(l, r))
}
}
}
// 返回对线段树的全部元素做query_op操作的结果
#[inline]
pub fn query_all(&mut self) -> T {
self.recursion_query(0, self.data.len() - 1, 0, 0, self.data.len() - 1)
}
// 返回对线段树的[l..r]范围全部元素做query_op操作的结果
pub fn query(&mut self, l: usize, r: usize) -> Result<T, &'static str> {
if l > self.data.len() || r > self.data.len() || l > r {
return Err("索引错误");
}
if l == r {
return Ok(self.data[l]);
}
Ok(self.recursion_query(l, r, 0, 0, self.data.len() - 1))
}
// 在index表示的[current_left,current_right]范围中查询[l..r]值
fn recursion_query(
&mut self,
l: usize,
r: usize,
index: usize,
current_left: usize,
current_right: usize,
) -> T {
if l > current_right || r < current_left {
return T::default();
}
if l == current_left && r == current_right {
if let Some(d) = self.tree[index] {
if l == r {
self.data[l] = d;
}
return d;
}
return T::default();
}
self.push_down(index, current_right - current_left + 1);
let mid = current_left + (current_right - current_left) / 2;
if l >= mid + 1 {
return self.recursion_query(l, r, Self::right_child(index), mid + 1, current_right);
} else if r <= mid {
return self.recursion_query(l, r, Self::left_child(index), current_left, mid);
}
let l_res = self.recursion_query(l, mid, Self::left_child(index), current_left, mid);
let r_res =
self.recursion_query(mid + 1, r, Self::right_child(index), mid + 1, current_right);
(self.query_op)(l_res, r_res)
}
// 更新index为val
pub fn set(&mut self, index: usize, val: T) -> Result<(), &'static str> {
if index >= self.data.len() {
return Err("索引超过线段树长度");
}
// 更新数据
self.data[index] = val;
// 递归更新树
self.recursion_set(0, 0, self.data.len() - 1, index, val);
Ok(())
}
// 递归更新树
fn recursion_set(&mut self, index_tree: usize, l: usize, r: usize, index: usize, val: T) {
if l == r {
self.tree[index_tree] = Some(val);
return;
}
let mid = l + (r - l) / 2;
let left_child = Self::left_child(index_tree);
let right_child = Self::right_child(index_tree);
if index >= mid + 1 {
self.recursion_set(right_child, mid + 1, r, index, val);
} else {
self.recursion_set(left_child, l, mid, index, val);
}
// 左右子树数据求和
if let Some(l_d) = self.tree[left_child] {
if let Some(r_d) = self.tree[right_child] {
self.tree[index_tree] = Some((self.query_op)(l_d, r_d));
}
}
}
// 应用所有懒标记到data数组上
#[inline]
pub fn apply_marker_all(&mut self) {
self.apply_marker_lr(0, self.data.len() - 1);
}
// 应用懒标记到[l:r]数据范围
#[inline]
pub fn apply_marker_lr(&mut self, l: usize, r: usize) {
self.apply_marker(l, r, 0, 0, self.data.len() - 1);
}
fn apply_marker(
&mut self,
l: usize,
r: usize,
index: usize,
current_l: usize,
current_r: usize,
) {
if current_l > r || current_r < l || r >= self.data.len() {
return; // 区间无交集
} else {
// 与目标区间有交集,但不包含于其中
if current_l == current_r {
if let Some(d) = self.tree[index] {
self.data[current_l] = d;
}
return;
}
let mid = (current_l + current_r) / 2;
self.push_down(index, current_r - current_l + 1);
self.apply_marker(l, r, Self::left_child(index), current_l, mid); // 递归地往下寻找
self.apply_marker(l, r, Self::right_child(index), mid + 1, current_r);
self.tree[index] = Some((self.query_op)(
self.tree[Self::left_child(index)].unwrap(),
self.tree[Self::right_child(index)].unwrap(),
));
// 根据子节点更新当前节点的值
}
}
#[inline]
pub fn update_interval(&mut self, l: usize, r: usize, delta: T) {
self.update(l, r, delta, 0, 0, self.data.len() - 1);
}
// 传递marker到下级
fn push_down(&mut self, index: usize, len: usize) {
self.marker[Self::left_child(index)] =
(self.marker_marker_op)(self.marker[index], self.marker[Self::left_child(index)]); // 标记向下传递
self.marker[Self::right_child(index)] =
(self.marker_marker_op)(self.marker[index], self.marker[Self::right_child(index)]);
if self.tree[Self::left_child(index)].is_some() {
self.tree[Self::left_child(index)] = Some((self.marker_t_op)(
(self.marker_mul_usize)(self.marker[index], len - (len / 2)),
self.tree[Self::left_child(index)].unwrap(),
));
}
if self.tree[Self::right_child(index)].is_some() {
self.tree[Self::right_child(index)] = Some((self.marker_t_op)(
(self.marker_mul_usize)(self.marker[index], len / 2),
self.tree[Self::right_child(index)].unwrap(),
));
}
self.marker[index] = T::default(); // 清除标记
}
fn update(
&mut self,
l: usize,
r: usize,
delta: T,
index: usize,
current_l: usize,
current_r: usize,
) {
if current_l > r || current_r < l {
return; // 区间无交集
} else if current_l >= l && current_r <= r {
// 当前节点对应的区间包含在目标区间中
if self.tree[index].is_some() {
// 更新当前区间的值
self.tree[index] = Some((self.query_op)(
self.tree[index].unwrap(),
(self.marker_mul_usize)(delta, current_r - current_l + 1),
));
}
// 如果不是叶子节点
if current_r > current_l {
// 给当前区间打上标记
self.marker[index] = (self.marker_marker_op)(delta, self.marker[index]);
}
} else {
// 与目标区间有交集,但不包含于其中
let mid = (current_l + current_r) / 2;
self.push_down(index, current_r - current_l + 1);
self.update(l, r, delta, Self::left_child(index), current_l, mid); // 递归地往下寻找
self.update(l, r, delta, Self::right_child(index), mid + 1, current_r);
self.tree[index] = Some((self.query_op)(
self.tree[Self::left_child(index)].unwrap(),
self.tree[Self::right_child(index)].unwrap(),
)); // 根据子节点更新当前节点的值
}
}
}
fn main() {
let mut tr: SegmentTree<i32> = SegmentTree::new(
vec![1, 3, 4, 0, 0, 4, 5, 0],
Box::new(|a, b| a + b),
Box::new(|a, b| a + b),
Box::new(|a, b| a + b),
Box::new(|a, b| a * (b as i32)),
);
let _ = tr.set(1, 2); //点更新,即把data[1]设为2
tr.update_interval(0, 2, -1); //区间更新,即[0:2]每个元素减1
tr.update_interval(1, 3, 2); //区间更新,即[1:3]每个元素加2
tr.apply_marker_all(); //应用全部marker到data数组
println!("{}", tr.query_all()); //输出19,即全部元素的和
println!("{:?}", tr.data); //输出[0, 3, 5, 2, 0, 4, 5, 0]
}
做一道题验证一下这个线段树的正确性,直接看我写的1589. 所有排列中的最大和题解即可(虽然这道题用差分数组最快,但是作为线段树验证还是很方便的)。
Rust实现线段树和懒标记的更多相关文章
- 线段树初步&&lazy标记
线段树 一.概述: 线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点. 对于线段树中的每一个非叶子节点[a,b],它的左儿子表示的区间为[a, ...
- 【BZOJ-2892&1171】强袭作战&大sz的游戏 权值线段树+单调队列+标记永久化+DP
2892: 强袭作战 Time Limit: 50 Sec Memory Limit: 512 MBSubmit: 45 Solved: 30[Submit][Status][Discuss] D ...
- BZOJ 1798 (线段树||分块)的标记合并
我原来准备做方差的.. 结果发现不会维护两个标记.. 就是操作变成一个 a*x+b ,每次维护a , b 即可 加的时候a=1 ,b=v 乘的时候a=v ,b=0 #include <cstdi ...
- FZU 2171(线段树的延迟标记)
题意:容易理解. 分析:时隔很久,再一次写了一道线段树的代码,之前线段树的题也做了不少,包括各种延迟标记,但是在组队分任务之后,我们队的线段树就交给了另外一个队友在搞, 然后我就一直没去碰线段树的题了 ...
- POJ 3237 Tree (树链剖分 路径剖分 线段树的lazy标记)
题目链接:http://poj.org/problem?id=3237 一棵有边权的树,有3种操作. 树链剖分+线段树lazy标记.lazy为0表示没更新区间或者区间更新了2的倍数次,1表示为更新,每 ...
- hdu 1828 Picture(线段树 || 普通hash标记)
http://acm.hdu.edu.cn/showproblem.php?pid=1828 Picture Time Limit: 6000/2000 MS (Java/Others) Mem ...
- poj3468 线段树的懒惰标记
题目链接:poj3468 题意:给定一段数组,有两种操作,一种是给某段区间加c,另一种是查询一段区间的和 思路:暴力的方法是每次都给这段区间的点加c,查询也遍历一遍区间,复杂度是n*n,肯定过不去,另 ...
- HDU 4107 Gangster(线段树 特殊懒惰标记)
两种做法. 第一种:标记区间最大值和最小值,若区间最小值>=P,则本区间+2c,若区间最大值<P,则本区间+c.非常简单的区间更新. 最后发一点牢骚:最后query查一遍就行,我这个2B竟 ...
- 浅谈算法——线段树之Lazy标记
一.前言 前面我们已经知道线段树能够进行单点修改和区间查询操作(基本线段树).那么如果需要修改的是一个区间该怎么办呢?如果是暴力修改到叶子节点,复杂度即为\(O(nlog n)\),显然是十分不优秀的 ...
- HDU 3397 线段树 双懒惰标记
这个是去年遗留历史问题,之前思路混乱,搞了好多发都是WA,就没做了 自从上次做了大白书上那个双重懒惰标记的题目,做这个就思路很清晰了 跟上次大白上那个差不多,这个也是有一个sets标记,代表这个区间全 ...
随机推荐
- 文心一言 VS 讯飞星火 VS chatgpt (96)-- 算法导论9.3 1题
一.用go语言,在算法 SELECT 中,输人元素被分为每组 5 个元素.如果它们被分为每组 7个元素,该算法仍然会是线性时间吗?证明:如果分成每组 3 个元素,SELECT 的运行时间不是线性的. ...
- 【RocketMQ】事务实现原理总结
RocketMQ事务的使用场景 单体架构下的事务 在单体系统的开发过程中,假如某个场景下需要对数据库的多张表进行操作,为了保证数据的一致性,一般会使用事务,将所有的操作全部提交或者在出错的时候全部回滚 ...
- CIC滤波器仿真与实验过程及结果记录
整理于2023-10-08 0.0 前言: 前面介绍了使用matlab中的Filter Designer工具箱进行CIC抽取滤波器设计的仿真过程与结果.下面在前面的基础上针对现有的[正点原子ZYNQ] ...
- Arduino基础入门之三按键开关
目的:通过读取按键开关的信号,实现其他器件的控制 难点:下拉电阻和上拉电阻 一.关于按键开关 按键开关如上图[1]所示,但我拿到实物,最令我头疼的是按钮下边4个角,我不知那两边是相通的(就是和图中12 ...
- 安装vscode
1.下载vscode安装包 因为vscode官网下载太慢, 所以从360的软件库下载: https://baoku.360.cn/soft/search?kw=vscode 2.直接点击安装 3.设置 ...
- Epic资源转到unity的方法
众所周知,unity中的素材主要是通过unity资源商店获取的.但是unity资源商店的白嫖机会太少了,而隔壁UE的Epic资源商店就有每月免费的资源,不白嫖成何体统?但是UE咱也不会用啊,白嫖的资源 ...
- 使用 Ant Design Vue 你可能会遇到的14个问题
公司有一个新需求,在原来项目基础上开发,项目中使用 Ant Design Vue,版本是 1.X ,在此记录下遇到的问题:对于没有使用过或者使用程度不深的同学来说,希望可以帮助你在开发中遇到问题时有个 ...
- Django+celery+eventlet+flower+redis异步任务创建及查询实现
1.环境版本:Django 3.2.12celery 5.3.4eventlet 0.33.3flower 2.0.1redis 3.5.3项目名称:new_project 2.celery配置(se ...
- 业务出海、高效传输、动态加速,尽在云栖大会「CDN与边缘计算」专场
2023杭州·云栖大会,即将热力来袭. 一场云计算盛会,500+前沿话题,3000+科技展品,与阿里云一起,共赴72小时的Tech沉浸之旅. 今日,「CDN与边缘计算」Tech专场,重磅议题抢先知晓! ...
- 自然数的拆分问题(lgP2404)
dfs.又调了一个小时,窝果然菜 需要传递的变量分别为目前搜索的数字:目前所有选中数字的和:目前所选数字个数. 见注释. #include<bits/stdc++.h> using nam ...