※ 글쓴이는 취미로 코딩을 익혀보는 사람이라 정확하지 않은 내용을 담고 있을 수 있다 ※
이번에 볼 문제는 백준 22487번 문제인 Do use segment tree이다.
문제는 아래 링크를 확인하자.
https://www.acmicpc.net/problem/22487
22487번: Do use segment tree
Given a tree with n (1 ≤ n ≤ 200,000) nodes and a list of q (1 ≤ q ≤ 100,000) queries, process the queries in order and output a value for each output query. The given tree is connected and each node on the tree has a weight wi (-10,000 ≤ wi ≤
www.acmicpc.net
주어지는 트리 위의 경로들에 대한 최대연속합 쿼리를 차례대로 처리하는 문제이다.
먼저 HLD를 이용해 어떤 두 점을 잇는 경로도 O(lgN)개의 구간으로 쪼갤 수 있게끔 트리를 분할해두자.
그리고 각 부분구간에 대하여 해당 구간의 최대연속합, 최대 prefix sum, 최대 postfix sum 및 전체 sum을 관리하는 lazy prop 세그먼트트리를 이용해 문제를 해결해주자.
lazy propagation을 구현할 때, 업데이트해야 할 가중치가 0이 될 수 있음에 유의하자.
또한, 글쓴이의 lazy prop 세그먼트트리는 경로 위 모든 노드의 가중치가 음수일 경우 빈 구간 0을 쿼리의 답으로 제시하지만 이 문제에서는 빈 구간의 경우를 제외한 답을 생각해야 한다. 이 경우 쿼리의 답은 전체 경로 위 노드의 가중치의 최댓값으로 구할 수 있음을 이용해 rangemax 세그먼트트리를 추가로 구현하는 것으로 답이 음수인 경우의 예외처리를 해줄 수 있다.
아래는 제출한 소스코드이다.
#include <iostream>
#include <vector>
#include <utility>
using namespace std;
int N, Q;
int W[200001];
int WW[200001];
vector<int> G[200001];
int heavychild[200001];
int dfs1(int cur, int par) {
int ret = 1;
int& hchild = heavychild[cur], hcnt = -1;
for (auto& nxt : G[cur]) {
if (nxt == par) continue;
int tmp = dfs1(nxt, cur);
ret += tmp;
if (hcnt < tmp) hcnt = tmp, hchild = nxt;
}
return ret;
}
int dfsidx[200001];
int cidx[200001];
int cnth[200001];
int cpar[200001];
int cdepth[200001];
int dfi = 1, ci = 1;
void dfs2(int cur, int par, int lv) {
int& id = dfsidx[cur] = dfi;
WW[id] = W[cur];
cidx[id] = ci, cdepth[id] = lv;
int& hchild = heavychild[cur];
if (hchild) {
dfi++;
cpar[dfi] = cpar[id];
cnth[dfi] = cnth[id] + 1;
dfs2(hchild, cur, lv);
}
lv++;
for (auto& nxt : G[cur]) {
if (nxt == par || nxt == hchild) continue;
ci++, dfi++;
cpar[dfi] = id;
cnth[dfi] = 0;
dfs2(nxt, cur, lv);
}
}
int segL[524289], segR[524289], segT[524289], segV[524289], segM[524289];
int lazy[524289]; bool lazychk[524289];
void init(int L, int R, int sI) {
if (L < R) {
init(L, (L + R) / 2, sI * 2), init((L + R) / 2 + 1, R, sI * 2 + 1);
segL[sI] = max(segL[sI * 2], segT[sI * 2] + segL[sI * 2 + 1]);
segR[sI] = max(segR[sI * 2 + 1], segR[sI * 2] + segT[sI * 2 + 1]);
segV[sI] = max(segR[sI * 2] + segL[sI * 2 + 1], max(segV[sI * 2], segV[sI * 2 + 1]));
segT[sI] = segT[sI * 2] + segT[sI * 2 + 1];
segM[sI] = max(segM[sI * 2], segM[sI * 2 + 1]);
}
else segL[sI] = segR[sI] = segT[sI] = segV[sI] = segM[sI] = WW[L];
}
void propagate(int L, int R, int sI) { // 업데이트값이 "0"일 수도 있음!! 주의!!
lazychk[sI] = 0;
if (lazy[sI] < 0) {
segL[sI] = segR[sI] = segV[sI] = 0;
segT[sI] = lazy[sI] * (R - L + 1);
}
else segL[sI] = segR[sI] = segV[sI] = segT[sI] = lazy[sI] * (R - L + 1);
segM[sI] = lazy[sI];
if (L < R) {
lazychk[sI * 2] = lazychk[sI * 2 + 1] = 1;
lazy[sI * 2] = lazy[sI * 2 + 1] = lazy[sI];
}
}
void upd(int L, int R, int qL, int qR, int qVal, int sI) {
if (lazychk[sI]) propagate(L, R, sI);
if (R < qL || qR < L) return;
if (qL <= L && R <= qR) {
lazy[sI] = qVal, lazychk[sI] = 1;
propagate(L, R, sI);
return;
}
upd(L, (L + R) / 2, qL, qR, qVal, sI * 2), upd((L + R) / 2 + 1, R, qL, qR, qVal, sI * 2 + 1);
segL[sI] = max(segL[sI * 2], segT[sI * 2] + segL[sI * 2 + 1]);
segR[sI] = max(segR[sI * 2 + 1], segR[sI * 2] + segT[sI * 2 + 1]);
segV[sI] = max(segR[sI * 2] + segL[sI * 2 + 1], max(segV[sI * 2], segV[sI * 2 + 1]));
segT[sI] = segT[sI * 2] + segT[sI * 2 + 1];
segM[sI] = max(segM[sI * 2], segM[sI * 2 + 1]);
}
struct qrycontainer {
int sL, sR, sV, sT, sM;
qrycontainer() {};
qrycontainer(int sL, int sR, int sV, int sT, int sM) {
this->sL = sL, this->sR = sR, this->sV = sV, this->sT = sT, this->sM = sM;
}
};
qrycontainer qry(int L, int R, int qL, int qR, int sI) {
if (lazychk[sI]) propagate(L, R, sI);
if (R < qL || qR < L) return qrycontainer(0, 0, 0, 0, -1000000007);
if (qL <= L && R <= qR) return qrycontainer(segL[sI], segR[sI], segV[sI], segT[sI], segM[sI]);
qrycontainer q1 = qry(L, (L + R) / 2, qL, qR, sI * 2), q2 = qry((L + R) / 2 + 1, R, qL, qR, sI * 2 + 1);
int retL = max(q1.sL, q1.sT + q2.sL), retR = max(q2.sR, q1.sR + q2.sT), retV = max(q1.sR + q2.sL, max(q1.sV, q2.sV)), retT = q1.sT + q2.sT, retM = max(q1.sM, q2.sM);
return qrycontainer(retL, retR, retV, retT, retM);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cin >> N >> Q;
for (int i = 1; i <= N; i++) cin >> W[i];
for (int i = 1; i < N; i++) {
int x, y; cin >> x >> y;
G[x].emplace_back(y);
G[y].emplace_back(x);
}
dfs1(1, 0);
dfs2(1, 0, 0);
init(1, N, 1);
while (Q--) {
int q, x, y, w; cin >> q >> x >> y >> w;
x = dfsidx[x], y = dfsidx[y];
if (cdepth[x] < cdepth[y]) swap(x, y);
if (q == 1) {
while (cdepth[x] > cdepth[y]) {
upd(1, N, x - cnth[x], x, w, 1);
x = cpar[x];
}
while (cidx[x] != cidx[y]) {
upd(1, N, x - cnth[x], x, w, 1);
upd(1, N, y - cnth[y], y, w, 1);
x = cpar[x], y = cpar[y];
}
if (x < y) upd(1, N, x, y, w, 1);
else upd(1, N, y, x, w, 1);
}
else {
int mx = -1000000007;
int xL = 0, xR = 0, xV = 0, xT = 0, yL = 0, yR = 0, yV = 0, yT = 0;
while (cdepth[x] > cdepth[y]) {
auto qc = qry(1, N, x - cnth[x], x, 1);
int xl = max(qc.sL, qc.sT + xL);
int xr = max(xR, qc.sR + xT);
int xv = max(qc.sR + xL, max(xV, qc.sV));
int xt = xT + qc.sT;
xL = xl, xR = xr, xV = xv, xT = xt;
mx = max(mx, qc.sM);
x = cpar[x];
}
while (cidx[x] != cidx[y]) {
auto qcx = qry(1, N, x - cnth[x], x, 1);
int xl = max(qcx.sL, qcx.sT + xL);
int xr = max(xR, qcx.sR + xT);
int xv = max(qcx.sR + xL, max(xV, qcx.sV));
int xt = xT + qcx.sT;
xL = xl, xR = xr, xV = xv, xT = xt;
mx = max(mx, qcx.sM);
auto qcy = qry(1, N, y - cnth[y], y, 1);
int yl = max(qcy.sL, qcy.sT + yL);
int yr = max(yR, qcy.sR + yT);
int yv = max(qcy.sR + yL, max(yV, qcy.sV));
int yt = yT + qcy.sT;
yL = yl, yR = yr, yV = yv, yT = yt;
mx = max(mx, qcy.sM);
x = cpar[x], y = cpar[y];
}
if (x < y) {
auto qc = qry(1, N, x, y, 1);
int yl = max(qc.sL, qc.sT + yL);
int yr = max(yR, qc.sR + yT);
int yv = max(qc.sR + yL, max(yV, qc.sV));
int yt = yT + qc.sT;
yL = yl, yR = yr, yV = yv, yT = yt;
mx = max(mx, qc.sM);
if (mx < 0) cout << mx << '\n';
else cout << max(xL + yL, max(xV, yV)) << '\n';
}
else {
auto qc = qry(1, N, y, x, 1);
int xl = max(qc.sL, qc.sT + xL);
int xr = max(xR, qc.sR + xT);
int xv = max(qc.sR + xL, max(xV, qc.sV));
int xt = xT + qc.sT;
xL = xl, xR = xr, xV = xv, xT = xt;
mx = max(mx, qc.sM);
if (mx < 0) cout << mx << '\n';
else cout << max(xL + yL, max(xV, yV)) << '\n';
}
}
}
}
'BOJ' 카테고리의 다른 글
[BOJ 16211 // C++] 백채원 (0) | 2023.09.18 |
---|---|
[BOJ 3024 // C++] 마라톤 틱택토 (0) | 2023.09.17 |
[BOJ 16206 // C++] 롤케이크 (0) | 2023.09.15 |
[BOJ 16207 // C++] 직사각형 (0) | 2023.09.14 |
[BOJ 16210 // C++] DSHS Bank (0) | 2023.09.13 |