参考文章
浅谈线段树
线段树
线段树是基于分治思想的二叉树,用于维护区间信息(区间和、区间最值、区间GCD等),可以在logn的时间内执行区间修改和区间查询。
线段树中每个叶子节点储存元素本身,非子叶节点存储区间内元素的统计值。
- 优点:较快完成区间更新和查询
- 缺点:空间大(2n - 4n)
代码实现
结构体包含三个变量:l,r,sum;
所有端点和区间和。
构建
递归建树
父节点编号为p,左孩子为2p,右孩子为2p+1;
1 2 3 4 5 6 7 8 9 10 11 12 13
| void build(int l,int r,int k) { tree[k].l=l;tree[k].r=r; if(l==r) { scanf("%d",&tree[k].w); return ; } int m=(l+r)/2; build(l,m,k*2); build(m+1,r,k*2+1); tree[k].w=tree[k*2].w+tree[k*2+1].w; }
|
单点查询
1 2 3 4 5 6 7 8 9 10 11
| void ask(int k) { if(tree[k].l==tree[k].r) { ans=tree[k].w; return ; } int m=(tree[k].l+tree[k].r)/2; if(x<=m) ask(k*2); else ask(k*2+1); }
|
单点修改
从根节点介入,从下往上更新。
1 2 3 4 5 6 7 8 9 10 11 12
| void add(int k) { if(tree[k].l==tree[k].r) { tree[k].w+=y; return; } int m=(tree[k].l+tree[k].r)/2; if(x<=m) add(k*2); else add(k*2+1); tree[k].w=tree[k*2].w+tree[k*2+1].w; }
|
区间查询
拆分与拼凑
- 区间完全覆盖,回溯,返回sum;
- 左节点与[x,y]有重叠,递归返回左子树;
- 右节点与[x,y]有重叠,递归返回右子树;
1 2 3 4 5 6 7 8 9 10 11
| void sum(int k) { if(tree[k].l>=x&&tree[k].r<=y) { ans+=tree[k].w; return; } int m=(tree[k].l+tree[k].r)/2; if(x<=m) sum(k*2); if(y>m) sum(k*2+1); }
|
区间修改
当[x,y]完全覆盖区间[a,b]时,先修改该区间的sum值,再打上一个懒标记,下次再查询时再带上懒标记,可以把每次修改和查询的时间都控制到logn。
懒标记
1 2 3 4 5 6 7 8
| void down(int k) { tree[k*2].f+=tree[k].f; tree[k*2+1].f+=tree[k].f; tree[k*2].w+=tree[k].f*(tree[k*2].r-tree[k*2].l+1); tree[k*2+1].w+=tree[k].f*(tree[k*2+1].r-tree[k*2+1].l+1); tree[k].f=0; }
|
懒标记的区间修改
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| void add(int k) { if(tree[k].l>=a&&tree[k].r<=b) { tree[k].w+=(tree[k].r-tree[k].l+1)*x; tree[k].f+=x; return; } if(tree[k].f) down(k); int m=(tree[k].l+tree[k].r)/2; if(a<=m) add(k*2); if(b>m) add(k*2+1); tree[k].w=tree[k*2].w+tree[k*2+1].w; }
|
懒标记的单点查询
1 2 3 4 5 6 7 8 9 10 11 12
| void ask(int k) { if(tree[k].l==tree[k].r) { ans=tree[k].w; return ; } if(tree[k].f) down(k); int m=(tree[k].l+tree[k].r)/2; if(x<=m) ask(k*2); else ask(k*2+1); }
|
懒标记的区间查询
1 2 3 4 5 6 7 8 9 10 11 12
| void sum(int k) { if(tree[k].l>=x&&tree[k].r<=y) { ans+=tree[k].w; return; } if(tree[k].f) down(k) int m=(tree[k].l+tree[k].r)/2; if(x<=m) sum(k*2); if(y>m) sum(k*2+1); }
|