※ 글쓴이는 취미로 코딩을 익혀보는 사람이라 정확하지 않은 내용을 담고 있을 수 있다 ※
이번에 볼 문제는 백준 25973번 문제인 어지러운 트리이다.
문제는 아래 링크를 확인하자.
https://www.acmicpc.net/problem/25973
노드 r을 루트로 하는 노드 N개의 트리 위에서 적절한(O(NlgN)) 전처리후 노드 x를 LCA로 갖는 두 노드 a와 b의 순서쌍 (a, b) (단, a<b)의 개수를 묻는 쿼리를 개당 O(lgN)의 시간복잡도에 답을 구해 문제를 해결해보자.
두 노드 a와 b를 고르면 그 둘을 잇는 트리 상의 경로 또한 유일하게 정해진다. 이 경로 위에 x가 있으면서 a와 b 모두 (r을 루트로 삼았을 때) x를 루트로 갖는 서브트리 내에 포함되어있다면 (a,b)는 문제의 답이 될 수 있고, 그렇지 않다면 (a,b)가 답이 될 수 없다는 점을 관찰하자.
위와 같은 (a,b)의 개수는 각 x를 지나는 경로의 개수와, (x를 루트로 삼았을 때) r을 포함하면서 x를 포함하지 않는 가장 큰 서브트리에 들어있는 노드 개수 p만 알고 있다면 계산해낼 수 있다. 구체적으로 서술하면 다음과 같다. 위의 관찰에서 (a,b)의 개수는 (x를 지나는 경로의 개수) - (a와 b중 하나가 (r을 루트로 삼았을 때) x를 루트로 갖는 서브트리 내에 포함되지 않은 경로의 개수)와 같이 구할 수 있다는 것을 알 수 있다. 이는 p * (N - p -1)로 계산해낼 수 있다.
트리 위에서의 DP와 (sparse table을 이용한) LCA알고리즘의 원리를 이용해 위의 계산에 필요한 값을 전처리해두고 쿼리에 응답해 문제를 해결하자.
아래는 제출한 소스코드이다.
#include <iostream>
#include <vector>
#include <map>
using namespace std;
typedef long long ll;
int N, Q; int root = 1;
vector<int> G[200001];
int table[200001][18];
map<pair<int, int>, ll> mp;
ll routes[200001];
int depth[200001];
void solve() {
int q, z; cin >> q >> z;
if (q == 1) root = z;
else if (z == root) cout << routes[z] + N - 1 << '\n';
else {
int x = root, node;
if (depth[x] <= depth[z]) {
node = table[z][0];
}
else {
int k = 0, diff = depth[x] - depth[z] - 1;
while (diff) {
if (diff & 1) x = table[x][k];
k++, diff >>= 1;
}
if (table[x][0] == z) node = x;
else node = table[z][0];
}
ll mpznode = mp[make_pair(z, node)];
cout << routes[z] - mpznode * (N - mpznode - 1) + (N - mpznode - 1) << '\n';
}
}
ll dfs(int cur, int par) {
depth[cur] = depth[par] + 1;
table[cur][0] = par;
ll childcnt = 0;
for (auto& x : G[cur]) {
if (x == par) continue;
ll tmp = dfs(x, cur);
mp.insert(make_pair(make_pair(cur, x), tmp));
routes[cur] += childcnt * tmp;
childcnt += tmp;
}
routes[cur] += childcnt * (N - childcnt - 1);
mp.insert(make_pair(make_pair(cur, par), N - childcnt - 1));
return childcnt + 1;
}
void init() {
cin >> N >> Q;
for (int i = 1; i < N; i++) {
int x, y; cin >> x >> y;
G[x].emplace_back(y);
G[y].emplace_back(x);
}
dfs(1, 0);
for (int k = 1; k < 18; k++) {
for (int i = 1; i <= N; i++) {
table[i][k] = table[table[i][k - 1]][k - 1];
}
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
init();
while (Q--) solve();
}
'BOJ' 카테고리의 다른 글
[BOJ 25972 // C++] 도미노 무너트리기 (0) | 2022.11.16 |
---|---|
[BOJ 25991 // C++] Lots of Liquid (0) | 2022.11.16 |
[BOJ 1748 // C++] 수 이어 쓰기 1 (0) | 2022.11.14 |
[BOJ 25965 // C++] 미션 도네이션 (0) | 2022.11.14 |
[BOJ 2309 // C++] 일곱 난쟁이 (0) | 2022.11.13 |