// cnt是动态开点的节点下标 // ls/rs是储存左右节点下标的数组 // 初始调用时,int rt=0; update(rt,...); inlinevoidpushup(int rt){ a[rt] = a[ls[rt]] + a[rs[rt]]; } voidupdate(int& rt, int l, int r, int pos, int val)// 在pos位置插入元素val,注意是&rt { if (!rt) rt = ++cnt; if (l == r) { a[rt] += val; return; } int mid = l + r >> 1; if (pos <= mid) update(ls[rt], l, mid, pos, val); else update(rs[rt], mid + 1, r, pos, val); pushup(rt); }
查询
1 2 3 4 5 6 7 8 9 10 11 12 13
intquery(int rt, int l, int r, int x, int y) { if (!rt || x > y) return0; if(l==x&&r==y) return a[rt]; int mid = l + r >> 1; if(y<=mid) returnquery(ls[rt], l, mid, x, y); if(x>mid) returnquery(rs[rt], mid + 1, r, x, y); returnquery(ls[rt], l, mid, x, mid) + query(rs[rt], mid + 1, r, mid + 1, y); }
权值线段树
权值线段树:维护区间内各元素出现次数
作用
数列第k大/小
某个数的数列中排名
比某个数大的最小值/比某个数小的最大值
维护信息:叶子节点(元素出现次数),其他节点(子节点元素总数)
注意:本节的线段树实现没有记录L[rt] 和R[rt] 。
更新
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
voidpushup(int rt){ a[rt] = a[ls(rt)] + a[rs(rt)]; } voidupdate(int rt, int l, int r, int k, int cnt)// 值为k的数多了cnt个 { if (l == r) { a[rt] += cnt; // 动态开点:先前没有这个节点 return; } int mid = l + r >> 1; if (k <= mid) update(ls(rt), l, mid, k, cnt); else update(rs(rt), mid + 1, r, k, cnt); pushup(rt); }
查询元素val的个数:
1 2 3 4 5 6 7 8 9
intquery(int rt, int l, int r, int val) { if(l==r) return a[rt]; int mid = l + r >> 1; if(val<=mid) returnquery(ls(rt), l, mid, val); returnquery(rs(rt), mid + 1, r, val); }
查询第k大:注意mid的值,左边比右边多1,因此是k-mid
1 2 3 4 5 6 7 8 9
intquery(int rt, int l, int r, int k)// 查询第k大的数 { if (l == r) return l; int mid = l + r >> 1; if (k < mid) returnquery(rs(rt), mid + 1, r, k); // 右半部找第k大 returnquery(ls(rt), l, mid, k - mid); // 右边比k个数多,在左边继续找 }
查询区间[x,y] 有多少个数
1 2 3 4 5 6 7 8 9 10 11
intquery(int rt, int l, int r, int x, int y) { if (l == x && r == y) return a[rt]; int mid = l + r >> 1; if (y <= mid) returnquery(ls(rt), l, mid, x, y); if (x > mid) returnquery(rs(rt), mid + 1, r, x, y); returnquery(ls(rt), l, mid, x, mid) + query(rs(rt), mid + 1, r, mid + 1, y); }
voidbuild() { // m = 1 << ceil(_lg(n + 2)); for(m; m <= n + 1; m <<= 1); for(int i = 1; i <= n; i++) tree[i + m] = a[i]; for(int i = m - 1; i; i--) tree[i] = tree[ls(i)] + tree[rs(i)]; }
单点修改:
1 2 3 4 5
voidupdate(int pos, int val) { for(int i = pos + m; i; i >>= 1) tree[i] += val; }
查询区间和[l,r] ,首先扩充成(l−1,r+1) ,根据下面算法手推一下更好理解。
1 2 3 4 5 6 7 8 9 10
intquery(int l, int r) { int ans = 0; for(l += m-1, r += m+1; l ^ r ^ 1; l >>= 1, r >>= 1) { if (~l & 1) ans += tree[l ^ 1]; // 左端点是左儿子 if (r & 1) ans += tree[r ^ 1]; // 右端点是右儿子 } return ans; }