题解-luogu-p2633 Count on a tree(COT)

题目链接

给定一棵$n$个节点的树,每个节点上有一个权值。对于$m$次询问,需要输出$u$到$v$的最短路径上第$k$小的点权。

$1\le n\le 100000,1 \le m\le 100000$

感谢@tth37 的贡献

这题不难呀,怎么调了这么久? ——Mr. G

前置知识是主席树。在利用主席树求解区间第K小数时可以发现,主席树是一种类似前缀和的数据结构,具有和前缀和类似的区间加减及差分等优秀性质。在求解线性区间的第K小数时,我们需要将该区间内的所有数值信息扔到一棵主席树中,并在这棵主席树上左右递归,以找到第K小数;同样的,我们可以类比树上前缀和的操作,定义$s[u]$为从根节点到第$u$号节点的“前缀主席树”(感性理解谢谢)。那么,包含$u$到$v$上所有数值信息的主席树就应该是:
$$
s[u]+s[v]-s[lca(u,v)]-s[fa[lca(u,v)]]
$$
理解上式后,问题基本可以解决了。另外注意离散化和主席树的代码细节。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include<bits/stdc++.h>
using namespace std;
// 离散化操作
#define id(x) (lower_bound(b+1,b+L+1,a[x])-b)
#define rid(x) (b[x])

const int MAXN = 100005;

struct Node {
int l, r, sum;
}node[10000005];
int head[MAXN],cnt;

vector<int> G[MAXN];

int N, M, L, lastans;
int a[MAXN], b[MAXN];
int f[MAXN][19], dep[MAXN];

inline void build(Node &u, int l, int r) {
u.sum = 0;
if (l == r) return;
int mid = (l + r) >> 1;
build(node[u.l = ++cnt], l, mid);
build(node[u.r = ++cnt], mid + 1, r);
}

inline void insert(Node c, Node &u, int l, int r, int p) {
u.sum = c.sum + 1;
if (l == r) return;
int mid = (l + r) >> 1;
if(p <= mid)
insert(node[c.l], node[u.l = ++cnt], l, mid, p), u.r = c.r;
else
insert(node[c.r], node[u.r = ++cnt], mid+1, r, p), u.l = c.l;
}

inline void dfs(int u, int fa) {
insert(node[head[fa]], node[head[u] = ++cnt], 1, L, id(u));
f[u][0] = fa;
dep[u] = dep[fa] + 1;
for (register int i = 1; i <= 18; ++i)
f[u][i] = f[f[u][i-1]][i-1];
for (vector<int>::iterator it = G[u].begin(); it != G[u].end(); it++) {
int v = *it;
if (v == fa) continue;
dfs(v, u);
}
}

inline int Lca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
for (register int i = 18; i >= 0; --i) {
if (dep[f[u][i]] >= dep[v]) u = f[u][i];
}
if (u == v) return u;
for (register int i = 18; i >= 0; --i) {
if (f[u][i] != f[v][i])
u = f[u][i], v = f[v][i];
}
return f[u][0];
}

inline int query(Node x, Node y, Node z, Node w, int l, int r, int k) {
if (l == r) return l;
int sum = node[x.l].sum + node[y.l].sum - node[z.l].sum - node[w.l].sum;
int mid = (l + r) >> 1;
if(sum >= k) return query(node[x.l], node[y.l], node[z.l], node[w.l], l, mid, k);
return query(node[x.r], node[y.r], node[z.r], node[w.r], mid+1, r, k - sum);
}

inline int querypath(int u, int v, int k) {
int lca = Lca(u, v);
return rid(query(node[head[u]], node[head[v]], node[head[lca]], node[head[f[lca][0]]], 1, L, k));
}

int main() {
scanf("%d%d", &N, &M);
for (register int i = 1; i <= N; ++i)
scanf("%d", &a[i]), b[i] = a[i];
for (register int i = 1; i < N; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
sort(b + 1, b + N + 1);
L = unique(b + 1, b + N + 1) - (b + 1);
build(node[head[0] = ++cnt], 1, L);
dfs(1, 0);
for (register int i = 1; i <= M; ++i) {
int u, v, k;
scanf("%d%d%d", &u, &v, &k);
int nowans = querypath(u^lastans, v, k);
printf("%d\n", nowans);
lastans = nowans;
}
}

Comments

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×