※ 글쓴이는 취미로 코딩을 익혀보는 사람이라 정확하지 않은 내용을 담고 있을 수 있다 ※

 

이번에 볼 문제는 백준 18263번 문제인 Milk Visits이다.
문제는 아래 링크를 확인하자.

https://www.acmicpc.net/problem/18263 

 

18263번: Milk Visits

Farmer John is planning to build $N$ ($1 \leq N \leq 10^5$) farms that will be connected by $N-1$ roads, forming a tree (i.e., all farms are reachable from each-other, and there are no cycles). Each farm contains a cow with an integer type $T_i$ between $1

www.acmicpc.net

HLD(Heavy-Light Decomposition)로 트리를 분해해두고, 구간합 세그먼트 트리(segment tree)와 오프라인 쿼리(offline query) 테크닉으로 각 소의 타입별로 가중치를 부여했다 지웠다 하는 것으로 문제를 해결할 수 있다.

구체적으로, 두 노드를 잇는 경로 사이의 노드의 가중치를 모두 합했을 때 양수이면 1, 그렇지 않으면 0으로 생각할 수 있다.

 

아래는 제출한 소스코드이다.

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

vector<int> G[100001];

int heavychild[100001];
int dfs(int current, int parent) {
    int childcnt = 1;
    int hchild = -1;
    int hcount = -1;
    for (auto node : G[current]) {
        if (node == parent) continue;
        int temp = dfs(node, current);
        childcnt += temp;
        if (temp > hcount) {
            hcount = temp, hchild = node;
        }
    }
    heavychild[current] = hchild;
    return childcnt;
}

int nodeidx[100001];
int chainidx[100001];
int cparent[100001];
int cnth[100001];
int cdepth[100001];
int nidx = 1, cidx = 1;
void hld(int current, int parent, int cp, int cn, int cd) {
    nodeidx[current] = nidx;
    chainidx[nidx] = cidx;
    cparent[nidx] = cp;
    cnth[nidx] = cn;
    cdepth[nidx] = cd;

    int hchild = heavychild[current];
    if (hchild > 0) {
        nidx++;
        hld(hchild, current, cp, cn + 1, cd);
    }
    cd++;
    for (auto node : G[current]) {
        if (node == parent || node == hchild) continue;
        nidx++; cidx++;
        hld(node, current, nodeidx[current], 0, cd);
    }
}

vector<int> cowtype[100001];

struct query {
    int idx;
    int x, y, c;
    int ans;
};

bool qccomp(query q1, query q2) {
    return q1.c < q2.c;
}

bool qicomp(query q1, query q2) {
    return q1.idx < q2.idx;
}

vector<query> queries;

int seg[262145];

void update(int L, int R, int qI, int qVal) {
    int sI = 1;
    while (L < R) {
        seg[sI] += qVal;
        int mid = (L + R) / 2;
        if (qI <= mid) {
            R = mid;
            sI = sI * 2;
        }
        else {
            L = mid + 1;
            sI = sI * 2 + 1;
        }
    }
    seg[sI] += qVal;
}

int rangesum(int L, int R, int qL, int qR, int sI) {
    if (R < qL || qR < L) return 0;
    if (qL <= L && R <= qR) return seg[sI];
    return rangesum(L, (L + R) / 2, qL, qR, sI * 2) + rangesum((L + R) / 2 + 1, R, qL, qR, sI * 2 + 1);
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);

    int N, M; cin >> N >> M;
    for (int i = 1; i <= N; i++) {
        int x; cin >> x;
        cowtype[x].push_back(i);
    }

    for (int i = 1; i < N; i++) {
        int x, y; cin >> x >> y;
        G[x].push_back(y);
        G[y].push_back(x);
    }

    dfs(1, 0);
    hld(1, 0, 0, 0, 0);
    
    for (int i = 1;i <= M;i++) {
        int x, y, c; cin >> x >> y >> c;
        query temp; temp.idx = i, temp.x = x, temp.y = y, temp.c = c;
        queries.push_back(temp);
    }

    sort(queries.begin(), queries.end(), qccomp);

    int oldc = 0;
    for (int i = 0; i < M; i++) {
        query q = queries[i];
        if (q.c != oldc) {
            for (auto c : cowtype[oldc]) update(1, N, nodeidx[c], -1);
            for (auto c : cowtype[q.c]) update(1, N, nodeidx[c], 1);
            oldc = q.c;
        }
        
        int x = nodeidx[q.x], y = nodeidx[q.y];

        int temp = 0;

        if (cdepth[x] != cdepth[y]) {
            if (cdepth[x] < cdepth[y]) swap(x, y);
            while (cdepth[x] > cdepth[y]) {
                temp += rangesum(1, N, x - cnth[x], x, 1);
                x = cparent[x];
            }
        }

        while (chainidx[x] != chainidx[y]) {
            temp += rangesum(1, N, x - cnth[x], x, 1) + rangesum(1, N, y - cnth[y], y, 1);
            x = cparent[x], y = cparent[y];
        }

        if (x > y) swap(x, y);
        temp += rangesum(1, N, x, y, 1);

        if (temp > 0) queries[i].ans = 1;
        else queries[i].ans = 0;
    }

    sort(queries.begin(), queries.end(), qicomp);
    
    for (auto q : queries) {
        cout << q.ans;
    }
}
728x90

+ Recent posts