※ 글쓴이는 취미로 코딩을 익혀보는 사람이라 정확하지 않은 내용을 담고 있을 수 있다 ※
이번에 볼 문제는 백준 11658번 문제인 구간 합 구하기 3이다.
문제는 아래 링크를 확인하자.
https://www.acmicpc.net/problem/11658
11658번: 구간 합 구하기 3
첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는
www.acmicpc.net
2차원 세그먼트트리를 구현해보기 위해 이 문제를 풀어보았다.
2차원 세그먼트트리는 1차원 세그먼트트리를 각 노드로 갖는 자료구조로, 한 축으로 1차원 세그먼트트리를 만들듯이 구간을 쪼개고, 각 구간별로 그에 대응되는 다른 축방향 1차원 세그먼트트리를 하나씩 만드는 것으로 구현할 수 있다.
아래의 코드는 제출시 실행시간이 (996ms~시간초과)로 실행속도가 좋지 못하지만 지금의 시도를 이곳에 남겨둔다.
#include <iostream>
using namespace std;
int N, M;
int arr[1025][1025];
struct node_Col {
int val;
node_Col* l = nullptr;
node_Col* r = nullptr;
};
struct node_Row {
node_Row* l = nullptr;
node_Row* r = nullptr;
node_Col* rt = nullptr;
};
node_Row* root;
int init_Col1(node_Col* cur, int row, int L, int R) {
if (L == R) return cur->val = arr[row][L];
cur->l = new node_Col();
cur->r = new node_Col();
return cur->val = init_Col1(cur->l, row, L, (L + R) / 2) + init_Col1(cur->r, row, (L + R) / 2 + 1, R);
}
int init_Col2(node_Col* cur, node_Col* x, node_Col* y, int L, int R) {
if (L == R) return cur->val = x->val + y->val;
cur->l = new node_Col();
cur->r = new node_Col();
return cur->val = init_Col2(cur->l, x->l, y->l, L, (L + R) / 2) + init_Col2(cur->r, x->r, y->r, (L + R) / 2 + 1, R);
}
void init_Row(node_Row* cur, int L, int R) {
cur->rt = new node_Col();
if (L == R) init_Col1(cur->rt, L, 1, N);
else {
cur->l = new node_Row();
cur->r = new node_Row();
init_Row(cur->l, L, (L + R) / 2);
init_Row(cur->r, (L + R) / 2 + 1, R);
init_Col2(cur->rt, cur->l->rt, cur->r->rt, 1, N);
}
}
int upd_Col1(node_Col* cur, int qC, int qVal, int L, int R) {
if (qC < L || R < qC) return cur->val;
if (L == R) return cur->val = qVal;
return cur->val = upd_Col1(cur->l, qC, qVal, L, (L + R) / 2) + upd_Col1(cur->r, qC, qVal, (L + R) / 2 + 1, R);
}
int upd_Col2(node_Col* cur, node_Col* x, node_Col* y, int qC, int L, int R) {
if (qC < L || R < qC) return cur->val;
if (L == R) return cur->val = x->val + y->val;
return cur->val = upd_Col2(cur->l, x->l, y->l, qC, L, (L + R) / 2) + upd_Col2(cur->r, x->r, y->r, qC, (L + R) / 2 + 1, R);;
}
void upd_Row(node_Row* cur, int qR, int qC, int qVal, int L, int R) {
if (qR < L || R < qR) return;
if (L == R) {
upd_Col1(cur->rt, qC, qVal, 1, N);
return;
}
upd_Row(cur->l, qR, qC, qVal, L, (L + R) / 2);
upd_Row(cur->r, qR, qC, qVal, (L + R) / 2 + 1, R);
upd_Col2(cur->rt, cur->l->rt, cur->r->rt, qC, 1, N);
}
int query_Col(node_Col* cur, int qCL, int qCR, int L, int R) {
if (qCR < L || R < qCL) return 0;
if (qCL <= L && R <= qCR) return cur->val;
return query_Col(cur->l, qCL, qCR, L, (L + R) / 2) + query_Col(cur->r, qCL, qCR, (L + R) / 2 + 1, R);
}
int query_Row(node_Row* cur, int qRL, int qRR, int qCL, int qCR, int L, int R) {
if (qRR < L || R < qRL) return 0;
if (qRL <= L && R <= qRR) return query_Col(cur->rt, qCL, qCR, 1, N);
return query_Row(cur->l, qRL, qRR, qCL, qCR, L, (L + R) / 2) + query_Row(cur->r, qRL, qRR, qCL, qCR, (L + R) / 2 + 1, R);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cin >> N >> M;
for (int r = 1; r <= N; r++) {
for (int c = 1; c <= N; c++) {
cin >> arr[r][c];
}
}
root = new node_Row();
init_Row(root, 1, N);
while (M--) {
int q; cin >> q;
if (q) {
int r1, c1, r2, c2; cin >> r1 >> c1 >> r2 >> c2;
cout << query_Row(root, r1, r2, c1, c2, 1, N) << '\n';
}
else {
int r, c, val; cin >> r >> c >> val;
upd_Row(root, r, c, val, 1, N);
}
}
}
2차원 세그먼트트리와는 별개로, prefix sum을 행마다 따로 관리한다면 0번과 1번 쿼리에 대하여 각각 많아야 1024회의 연산만을 하면 되므로 제한시간 내로 문제를 해결할 수 있다.
아래는 제출한 소스코드이다.
#include <iostream>
using namespace std;
int arr[1025][1025];
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
int N, M; cin >> N >> M;
for (int r = 1; r <= N; r++) {
for (int c = 1; c <= N; c++) {
int x; cin >> x;
arr[r][c] = arr[r][c - 1] + x;
}
}
while (M--) {
int q; cin >> q;
if (q) {
int ans = 0;
int r1, c1, r2, c2; cin >> r1 >> c1 >> r2 >> c2;
c1--;
for (int r = r1; r <= r2; r++) {
ans += (arr[r][c2] - arr[r][c1]);
}
cout << ans << '\n';
}
else {
int r1, c1, val; cin >> r1 >> c1 >> val;
int d = val - arr[r1][c1] + arr[r1][c1 - 1];
while (c1 <= N)arr[r1][c1++] += d;
}
}
}
728x90
'BOJ' 카테고리의 다른 글
[BOJ 24498 // C++] blobnom (0) | 2022.03.05 |
---|---|
[BOJ 24511 // C++] queuestack (0) | 2022.03.04 |
[BOJ 24510 // C++] 시간복잡도를 배운 도도 (0) | 2022.03.02 |
[BOJ 1431 // C++] 시리얼 번호 (0) | 2022.03.01 |
[BOJ 7596 // C++] MP3 Songs (0) | 2022.02.28 |