※ 글쓴이는 취미로 코딩을 익혀보는 사람이라 정확하지 않은 내용을 담고 있을 수 있다 ※

 

이번에 볼 문제는 백준 22487번 문제인 Do use segment tree이다.
문제는 아래 링크를 확인하자.

https://www.acmicpc.net/problem/22487 

 

22487번: Do use segment tree

Given a tree with n (1 ≤ n ≤ 200,000) nodes and a list of q (1 ≤ q ≤ 100,000) queries, process the queries in order and output a value for each output query. The given tree is connected and each node on the tree has a weight wi (-10,000 ≤ wi ≤

www.acmicpc.net

주어지는 트리 위의 경로들에 대한 최대연속합 쿼리를 차례대로 처리하는 문제이다.

 

먼저 HLD를 이용해 어떤 두 점을 잇는 경로도 O(lgN)개의 구간으로 쪼갤 수 있게끔 트리를 분할해두자.

 

그리고 각 부분구간에 대하여 해당 구간의 최대연속합, 최대 prefix sum, 최대 postfix sum 및 전체 sum을 관리하는 lazy prop 세그먼트트리를 이용해 문제를 해결해주자.

 

lazy propagation을 구현할 때, 업데이트해야 할 가중치가 0이 될 수 있음에 유의하자.

 

또한, 글쓴이의 lazy prop 세그먼트트리는 경로 위 모든 노드의 가중치가 음수일 경우 빈 구간 0을 쿼리의 답으로 제시하지만 이 문제에서는 빈 구간의 경우를 제외한 답을 생각해야 한다. 이 경우 쿼리의 답은 전체 경로 위 노드의 가중치의 최댓값으로 구할 수 있음을 이용해 rangemax 세그먼트트리를 추가로 구현하는 것으로 답이 음수인 경우의 예외처리를 해줄 수 있다.

 

아래는 제출한 소스코드이다.

#include <iostream>
#include <vector>
#include <utility>
using namespace std;

int N, Q;
int W[200001];
int WW[200001];
vector<int> G[200001];

int heavychild[200001];
int dfs1(int cur, int par) {
	int ret = 1;
	int& hchild = heavychild[cur], hcnt = -1;
	for (auto& nxt : G[cur]) {
		if (nxt == par) continue;
		int tmp = dfs1(nxt, cur);
		ret += tmp;
		if (hcnt < tmp) hcnt = tmp, hchild = nxt;
	}

	return ret;
}

int dfsidx[200001];
int cidx[200001];
int cnth[200001];
int cpar[200001];
int cdepth[200001];
int dfi = 1, ci = 1;

void dfs2(int cur, int par, int lv) {
	int& id = dfsidx[cur] = dfi;
	WW[id] = W[cur];
	cidx[id] = ci, cdepth[id] = lv;
	int& hchild = heavychild[cur];
	if (hchild) {
		dfi++;
		cpar[dfi] = cpar[id];
		cnth[dfi] = cnth[id] + 1;
		dfs2(hchild, cur, lv);
	}
	lv++;
	for (auto& nxt : G[cur]) {
		if (nxt == par || nxt == hchild) continue;
		ci++, dfi++;
		cpar[dfi] = id;
		cnth[dfi] = 0;
		dfs2(nxt, cur, lv);
	}
}

int segL[524289], segR[524289], segT[524289], segV[524289], segM[524289];
int lazy[524289]; bool lazychk[524289];

void init(int L, int R, int sI) {
	if (L < R) {
		init(L, (L + R) / 2, sI * 2), init((L + R) / 2 + 1, R, sI * 2 + 1);
		segL[sI] = max(segL[sI * 2], segT[sI * 2] + segL[sI * 2 + 1]);
		segR[sI] = max(segR[sI * 2 + 1], segR[sI * 2] + segT[sI * 2 + 1]);
		segV[sI] = max(segR[sI * 2] + segL[sI * 2 + 1], max(segV[sI * 2], segV[sI * 2 + 1]));
		segT[sI] = segT[sI * 2] + segT[sI * 2 + 1];
		segM[sI] = max(segM[sI * 2], segM[sI * 2 + 1]);
	}
	else segL[sI] = segR[sI] = segT[sI] = segV[sI] = segM[sI] = WW[L];
}

void propagate(int L, int R, int sI) { // 업데이트값이 "0"일 수도 있음!! 주의!!
	lazychk[sI] = 0;
	if (lazy[sI] < 0) {
		segL[sI] = segR[sI] = segV[sI] = 0;
		segT[sI] = lazy[sI] * (R - L + 1);
	}
	else segL[sI] = segR[sI] = segV[sI] = segT[sI] = lazy[sI] * (R - L + 1);
	segM[sI] = lazy[sI];

	if (L < R) {
		lazychk[sI * 2] = lazychk[sI * 2 + 1] = 1;
		lazy[sI * 2] = lazy[sI * 2 + 1] = lazy[sI];
	}
}

void upd(int L, int R, int qL, int qR, int qVal, int sI) {
	if (lazychk[sI]) propagate(L, R, sI);
	if (R < qL || qR < L) return;
	if (qL <= L && R <= qR) {
		lazy[sI] = qVal, lazychk[sI] = 1;
		propagate(L, R, sI);
		return;
	}
	upd(L, (L + R) / 2, qL, qR, qVal, sI * 2), upd((L + R) / 2 + 1, R, qL, qR, qVal, sI * 2 + 1);
	segL[sI] = max(segL[sI * 2], segT[sI * 2] + segL[sI * 2 + 1]);
	segR[sI] = max(segR[sI * 2 + 1], segR[sI * 2] + segT[sI * 2 + 1]);
	segV[sI] = max(segR[sI * 2] + segL[sI * 2 + 1], max(segV[sI * 2], segV[sI * 2 + 1]));
	segT[sI] = segT[sI * 2] + segT[sI * 2 + 1];
	segM[sI] = max(segM[sI * 2], segM[sI * 2 + 1]);
}

struct qrycontainer {
	int sL, sR, sV, sT, sM;
	qrycontainer() {};
	qrycontainer(int sL, int sR, int sV, int sT, int sM) {
		this->sL = sL, this->sR = sR, this->sV = sV, this->sT = sT, this->sM = sM;
	}
};

qrycontainer qry(int L, int R, int qL, int qR, int sI) {
	if (lazychk[sI]) propagate(L, R, sI);
	if (R < qL || qR < L) return qrycontainer(0, 0, 0, 0, -1000000007);
	if (qL <= L && R <= qR) return qrycontainer(segL[sI], segR[sI], segV[sI], segT[sI], segM[sI]);
	qrycontainer q1 = qry(L, (L + R) / 2, qL, qR, sI * 2), q2 = qry((L + R) / 2 + 1, R, qL, qR, sI * 2 + 1);
	int retL = max(q1.sL, q1.sT + q2.sL), retR = max(q2.sR, q1.sR + q2.sT), retV = max(q1.sR + q2.sL, max(q1.sV, q2.sV)), retT = q1.sT + q2.sT, retM = max(q1.sM, q2.sM);
	return qrycontainer(retL, retR, retV, retT, retM);
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);

	cin >> N >> Q;
	for (int i = 1; i <= N; i++) cin >> W[i];
	for (int i = 1; i < N; i++) {
		int x, y; cin >> x >> y;
		G[x].emplace_back(y);
		G[y].emplace_back(x);
	}

	dfs1(1, 0);
	dfs2(1, 0, 0);
	init(1, N, 1);

	while (Q--) {
		int q, x, y, w; cin >> q >> x >> y >> w;
		x = dfsidx[x], y = dfsidx[y];
		if (cdepth[x] < cdepth[y]) swap(x, y);
		if (q == 1) {
			while (cdepth[x] > cdepth[y]) {
				upd(1, N, x - cnth[x], x, w, 1);
				x = cpar[x];
			}
			while (cidx[x] != cidx[y]) {
				upd(1, N, x - cnth[x], x, w, 1);
				upd(1, N, y - cnth[y], y, w, 1);
				x = cpar[x], y = cpar[y];
			}
			if (x < y) upd(1, N, x, y, w, 1);
			else upd(1, N, y, x, w, 1);
		}
		else {
			int mx = -1000000007;
			int xL = 0, xR = 0, xV = 0, xT = 0, yL = 0, yR = 0, yV = 0, yT = 0;
			while (cdepth[x] > cdepth[y]) {
				auto qc = qry(1, N, x - cnth[x], x, 1);
				int xl = max(qc.sL, qc.sT + xL);
				int xr = max(xR, qc.sR + xT);
				int xv = max(qc.sR + xL, max(xV, qc.sV));
				int xt = xT + qc.sT;
				xL = xl, xR = xr, xV = xv, xT = xt;
				mx = max(mx, qc.sM);

				x = cpar[x];
			}
			while (cidx[x] != cidx[y]) {
				auto qcx = qry(1, N, x - cnth[x], x, 1);
				int xl = max(qcx.sL, qcx.sT + xL);
				int xr = max(xR, qcx.sR + xT);
				int xv = max(qcx.sR + xL, max(xV, qcx.sV));
				int xt = xT + qcx.sT;
				xL = xl, xR = xr, xV = xv, xT = xt;
				mx = max(mx, qcx.sM);

				auto qcy = qry(1, N, y - cnth[y], y, 1);
				int yl = max(qcy.sL, qcy.sT + yL);
				int yr = max(yR, qcy.sR + yT);
				int yv = max(qcy.sR + yL, max(yV, qcy.sV));
				int yt = yT + qcy.sT;
				yL = yl, yR = yr, yV = yv, yT = yt;
				mx = max(mx, qcy.sM);

				x = cpar[x], y = cpar[y];
			}
			if (x < y) {
				auto qc = qry(1, N, x, y, 1);
				int yl = max(qc.sL, qc.sT + yL);
				int yr = max(yR, qc.sR + yT);
				int yv = max(qc.sR + yL, max(yV, qc.sV));
				int yt = yT + qc.sT;
				yL = yl, yR = yr, yV = yv, yT = yt;
				mx = max(mx, qc.sM);

				if (mx < 0) cout << mx << '\n';
				else cout << max(xL + yL, max(xV, yV)) << '\n';
			}
			else {
				auto qc = qry(1, N, y, x, 1);
				int xl = max(qc.sL, qc.sT + xL);
				int xr = max(xR, qc.sR + xT);
				int xv = max(qc.sR + xL, max(xV, qc.sV));
				int xt = xT + qc.sT;
				xL = xl, xR = xr, xV = xv, xT = xt;
				mx = max(mx, qc.sM);

				if (mx < 0) cout << mx << '\n';
				else cout << max(xL + yL, max(xV, yV)) << '\n';
			}
		}
	}
}
728x90

'BOJ' 카테고리의 다른 글

[BOJ 16211 // C++] 백채원  (0) 2023.09.18
[BOJ 3024 // C++] 마라톤 틱택토  (0) 2023.09.17
[BOJ 16206 // C++] 롤케이크  (0) 2023.09.15
[BOJ 16207 // C++] 직사각형  (0) 2023.09.14
[BOJ 16210 // C++] DSHS Bank  (0) 2023.09.13

+ Recent posts