题解-luogu-p4103大工程

题目链接

给定一棵有 $n$ 个节点的树,边权为 $1$ 。共有 $q$ 次询问,每次给出 $k$ 个节点,求: $k$ 个节点间两两距离之总和;最短距离;最长距离。

$1\le n \le 1000000,1 \le q \le 1000000, \Sigma{k}\le 2 * n$

感谢@tth37 的贡献

本题用到了一些点分治的思想。

考虑 $q=1$ 的情况。一种朴素的做法是:枚举当前节点的所有子节点,并计算子树间关键点形成的路径、更新答案。但是本题与一般点分治题目略有不同,我们可以通过预处理子树信息来优化点分治过程。

稍加观察可以发现,只需预处理每个子树中树根到关键点的最小距离、最大距离,以及子树中关键点的个数、所有关键点到树根的距离总和即可完成点分治全部过程,时间复杂度 $O(n)$ 。

对于 $q\not=1$ 的情况,观察到 $\Sigma{k}$ 与 $n$ 同阶,可以对每次查询建立一棵虚树,在虚树上点分治即可。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1000005;
typedef long long ll;
struct Edge {int v, w; Edge(int a, int b) {v = a, w = b;}};
struct Key {int u, dfn;}keys[MAXN];
bool cmp(Key a, Key b) {return a.dfn < b.dfn;}
vector<Edge> G[MAXN], VT[MAXN];
int N, Q, K;
int f[MAXN][21], dep[MAXN], dfn[MAXN], dfn_idx;
int lg[MAXN];
bool h[MAXN];
ll g[MAXN];
int m[MAXN], n[MAXN];
int c[MAXN];
ll ans1;
int ans2, ans3;

inline void dfs0(int u, int fa) {
dfn[u] = ++dfn_idx;
dep[u] = dep[fa] + 1;
f[u][0] = fa;
for (int i = 1; i <= lg[dep[u]]; ++i)
f[u][i] = f[f[u][i - 1]][i - 1];
for (vector<Edge>::iterator it = G[u].begin(); it != G[u].end(); it++) {
int v = it -> v;
if (v == fa) continue;
dfs0(v, u);
}
}

void dfs1(int u) {
g[u] = 0;
c[u] = h[u];
m[u] = 0x3f3f3f3f;
n[u] = -0x3f3f3f3f;
if (h[u]) m[u] = n[u] = 0;
for (vector<Edge>::iterator it = VT[u].begin(); it != VT[u].end(); it++) {
int v = it -> v, w = it -> w;
dfs1(v);
c[u] += c[v];
g[u] += g[v] + 1ll * w * c[v];
m[u] = min(m[u], w + m[v]);
n[u] = max(n[u], w + n[v]);
}
}

void dfs2(int u) {
ll sum = 0;
int cnt = h[u];
int minn = 0x3f3f3f3f, maxx = -0x3f3f3f3f;
if (h[u]) minn = maxx = 0;
for (vector<Edge>::iterator it = VT[u].begin(); it != VT[u].end(); it++) {
int v = it -> v, w = it -> w;
ans1 += 1ll * sum * c[v] + 1ll * w * cnt * c[v] + 1ll * g[v] * cnt;
ans2 = min(ans2, minn + w + m[v]);
ans3 = max(ans3, maxx + w + n[v]);
sum += g[v] + 1ll * c[v] * w;
cnt += c[v];
minn = min(minn, w + m[v]);
maxx = max(maxx, w + n[v]);
dfs2(v);
}
h[u] = 0;
VT[u].clear();
}

inline int Lca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
while (dep[u] > dep[v]) u = f[u][lg[dep[u] - dep[v]]];
if (u == v) return u;
for (int i = lg[dep[u]]; i >= 0; --i) {
if (f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
}
return f[u][0];
}

int main() {
for (register int i = 2; i <= 1000000; ++i)
lg[i] = lg[i >> 1] + 1;
scanf("%d", &N);
for (register int i = 1; i < N; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(Edge(v, 1));
G[v].push_back(Edge(u, 1));
}
dfs0(1, 0);
scanf("%d", &Q);
while (Q--) {
scanf("%d", &K);
for (register int i = 1; i <= K; ++i) {
int u;
scanf("%d", &u);
h[u] = 1;
keys[i].u = u, keys[i].dfn = dfn[u];
}
sort(keys + 1, keys + K + 1, cmp);
stack<int> s;
s.push(1);
for (register int i = 1; i <= K; ++i) {
int u = keys[i].u;
if (u == 1) continue;
int lca = Lca(u, s.top());
while (s.top() != lca) {
int tmp = s.top(); s.pop();
if (dfn[s.top()] < lca) s.push(lca);
VT[s.top()].push_back(Edge(tmp, dep[tmp] - dep[s.top()]));
}
s.push(u);
}
while (s.top() != 1) {
int tmp = s.top(); s.pop();
VT[s.top()].push_back(Edge(tmp, dep[tmp] - dep[s.top()]));
}
dfs1(1);
ans1 = ans3 = 0;
ans2 = 0x3f3f3f3f;
dfs2(1);
printf("%lld %d %d\n", ans1, ans2, ans3);
}
return 0;
}

Comments

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×