(+) 9345번 디지털 비디오 디스크(DVDs) 문제의 코드만 보시려면 본 게시물 최하단으로 바로 내려가시면 됩니다.
이 글은 세그먼트 트리에 대한 개념 설명과 기본 문제를 해결한 후, 최종적으로 9345번 문제에 대한 설명을 진행합니다.
1. 세그먼트 트리를 사용하는 문제 상황
이진트리의 활용 방법 중 하나인 세그먼트 트리(Segment Tree)를 이용하여 해결할 수 있는 문제이다.
세그먼트 트리를 활용하여 해결할 수 있는 문제들은 다양한데, 뽑아보자면 다음과 같다.
2268번: 수들의 합
첫째 줄에는 N(1≤N≤1,000,000), M(1≤M≤1,000,000)이 주어진다. M은 수행한 명령의 개수이며 다음 M개의 줄에는 수행한 순서대로 함수의 목록이 주어진다. 첫 번째 숫자는 어느 함수를 사용했는지를 ��
www.acmicpc.net
11505번: 구간 곱 구하기
첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 곱을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄�
www.acmicpc.net
9345번: 디지털 비디오 디스크(DVDs)
문제 최근 유튜브와 같은 온라인 비디오 스트리밍 서비스 때문에 DVD 대여점들이 자취를 감추고 있다. 이러한 어려운 상황 속에서, DVD 대여점 주인들은 실낱같은 희망을 잡고자 인기있는 N개의 DV
www.acmicpc.net
위 문제들의 공통점은 N개의 수를 가진 배열에서 특정 위치의 숫자를 변경하거나, 특정 구간의 정보를 요구하는 M개의 쿼리를 해결하는 문제이다.
위에 소개한 문제들 중 수들의 합(2268번) 문제로 예를 들어 보자면,
위와 같은 방식의 문제를 주로 세그먼트 트리로 해결한다.
'굳이 특정한 자료구조를 사용하지 않고 해결할 수 있지 않는가?' 라는 의문이 들 수 있지만,
위 문제에서는 N이 100만까지, 그리고 쿼리의 수 M이 100만까지이므로, 단순히 쿼리가 들어올 때마다 구간에 대한 합을 구하는 방법으로는 O(NM)의 시간복잡도를 가지게 되므로, 시간초과를 피할 수는 없다.
(여기서 O(NM)으로 해결하는 방법은 prefix sum을 이용하여 구간 합에 대한 정보를 매 쿼리마다 유지하는 방법이다)
결국 쿼리는 M개만큼 무조건 해결해야하니 N개의 수에 대한 구간 정보를 O(logN)의 시간만에 해결하여 O(MlogN)의 시간복잡도를 가지는 방법을 필요로 하는데, 이는 세그먼트 트리를 이용하면 된다.
2. 세그먼트 트리
그렇다면 세그먼트 트리가 무엇일까?
세그먼트 트리는 이진 트리의 구조를 활용하는 방법이다.
다음의 그림을 보자.
위 그림에서 트리의 각 노드에 적힌 숫자는 배열의 index라고 생각하면 된다. 즉, 배열 Array에 대하여 Array[1]은 루트 노드이고, 루트의 자식은 (Array[1 * 2] = Array[2])와 (Array[1 * 2 + 1] = Array[3])이 된다.
더 확장하자면, Array[i]의 자식은 각각 Array[i * 2]와 Array[i * 2 + 1]이 되며,
이 두 자식은 Array[i]에 붙어있는 두 개의 서브트리의 루트 노드가 되는 것이다.
하지만, 세그먼트 트리에서는 Array[1]부터 Array[N]까지 배열의 원소를 차곡차곡 넣는 것이 아니라, 리프 노드에만 배열의 원소를 삽입해 둔다.
리프 노드를 제외한, 즉 자식이 존재하는 모든 노드에는 자신의 밑에 달린 서브트리에 대한 정보를 담게 된다.
말로 표현하기 힘드니 수들의 합(2268번) 문제의 상황을 그림으로 설명해 보자면,
와 같이 나타낼 수 있는데, 여기서 리프 노드에는 배열 Array = {2, 4, 4, 7, 3, 5, 7, 1}가 순서대로 저장되어 있고, 부모 노드에는 자식 노드의 합을 가지고 있는 모습이다.
우선 데이터를 저장하는 아이디어는 비교적 간단하다. 트리의 리프에는 배열의 값들을, 그 외에는 자식 노드의 합을 저장해주면 되니까.
다른 유형, 예를 들어 구간 최소, 최대, 곱 등등도 그냥 해당 계산을 수행해서 저장만 해주면 된다.
그렇다면, 실제 코드에서는 어떻게 구현해야 할까?
1) 세그먼트 트리의 초기화(Initialization)
먼저 생각해봐야 할 것은 저장 공간이 얼마나 되는가? 인데, 배열의 원소가 N개 라고 한다면, 리프 노드의 수는
2^k(2의 거듭제곱 꼴) >= N
을 만족하는 최초의(최소의) 2^k개 만큼은 필요하다.
(물론 차이만큼은 사용하지 않는 공간이겠지만 재귀 함수를 사용할 때 일관된 표현을 위해 위와 같이 사용한다)
그리고, 나머지 노드들은 (2^k - 1)개 만큼 필요한데, 이는 위의 그림을 보아도 되고, 이진 트리의 특징을 생각해보면 이해하기 쉽다.
결국 2^k + 2^k - 1 = 2^(k+1) - 1개 만큼은 필요하며, 편의상 2*2^k개 만큼은 필요하다고 생각하면 된다.
즉 N개를 넘는 최초의 2^k개의 2배만큼은 필요하다. (말이 뭔가 애매하다..)
코드로는 간단히 구할 수 있지 싶다.
int Array[2100000];
int bound = 1;
while (bound < N)
bound *= 2;
여기서 bound는 2^k이고, N이 100만이라면 bound는 1048576일테고 Array는 (2097152 - 1)개 만큼은 공간이 있어야 하지만 귀찮아서 210만으로 잡겠다..
bound를 굳이 코드로 구한 이유는 이제부터 설명하겠다. 다시 그림으로 돌아가보자.
위 그림처럼 만약 N이 8이었다면, bound 또한 8이 나올 것이고, Array는 적어도 15개의 공간은 필요하다.
재미있는 사실은, 리프노드의 시작 index가 8이고 bound도 8이라는 점이다.
따라서, 그냥 bound부터 원소를 집어넣으면 되는데,
for (int i = bound; i < bound + N; i++)
scanf("%d", &Array[i]); // bound(리프 노드의 시작)부터 N개 만큼 원소 저장
for (int i = bound + N; i < 2 * bound; i++)
Array[i] = 0; // 리프의 나머지 빈 배열 초기화
주로 stdin으로 받을 테니 이런 식으로 구현하면 되겠다. 수들의 합(2268번) 문제에서는 전역변수로 Array를 선언하는 경우 상관은 없지만, 밑의 반복문은 꼭 상기해두어야하는 부분이다.
가령 구간 합이 아니라 구간 최소라고 한다면 0이 아니라 다른 값이 들어가야 하고, 이 경우에는 반드시 밑의 반복문이 필요하다.
이제 리프노드에는 배열 값들을 저장하였으니, 나머지 노드에 대한 처리만 해주면 되겠다.
for (int i = bound - 1; i > 0; i--)
Array[i] = Array[i * 2] + Array[i * 2 + 1]; // 두 자식의 합을 저장
사실 수들의 합(2268번) 문제를 해결할 때에는 배열 값이 모두 0으로 시작하므로 굳이 필요없지만, 배열을 주고 시작하는 문제도 많으니까, 알아두어야 하는 부분이다.
2) Update, Find
세그먼트 트리를 만들어 두었으니, 필요한 함수를 구현해보자.
우선 세그먼트 트리는 Update, Find 함수 두 가지가 핵심이다. (이름은 마음대로 지었다)
Update는 배열의 특정 원소의 값이 변경되는 경우, 그 원소의 부모부터 루트까지의 값들을 갱신해주는 함수이고,
Find는 루트부터 시작해 원하는 구간의 값(합, 곱, 최대, 최소...)을 찾아내는 함수이다.
그럼 Update부터 살펴보자. 우선 그림을 다시 봐야겠는데,
만약 위의 트리에서 Array[9]의 값이 4에서 9로 바뀌었다고 가정해보자.
그럼 밑의 그림처럼 갱신이 되어야 할 텐데,
초록색 글씨를 따라가 보자.
Array[9]를 9로 갱신하면, 그의 부모인 Array[4], Array[2], Array[1]이 순서대로 바뀌었고 나머지 노드는 건드리지 않았다.
만약 Array[12]를 갱신했다면 Array[6], Array[3], Array[1]이 바뀔 것이고 ... Array[15]를 갱신하면,
Array[15 / 2] = Array[7], Array[7 / 2] = Array[3], Array[3 / 2] = Array[1]이 바뀌게 되네..?
오 그렇다면, Array[K]의 값을 갱신하면
Array[K / 2], Array[K / 2^2], Array[K / 2^3] ... 의 값이 갱신이 되겠네!
사실 대단한 발견같지만, 이진트리의 특징을 생각해보면 그렇게 대단한 부분은 아니다..
아무튼, 코드로도 구현하기 쉬울 것 같다.
void Update(int K, int num) // K번째 원소를 num으로 갱신
{
int idx = bound + K;
Array[idx] = num;
idx /= 2;
while(idx > 0)
{
Array[idx] = Array[idx * 2] + Array[idx * 2 + 1];
idx /= 2;
}
}
사실 앞서 설명한 초기화 부분을 끌어다 쓴 것 뿐이다.
여기서 bound는 앞서 설명한 리프노드의 시작이고, K가 0부터 N - 1까지의 값이라고 할 때를 가정하였다.
시간복잡도는 가볍게 O(logN)임을 알 수 있다.
이제 핵심인 Find 함수를 알아보자.
나는 구간 [9, 14]의 합을 알고 싶다. 그런데 구현해놓은 세그먼트 트리의 노드들을 보면,
[9, 14]의 합을 담고 있는 노드는 없다..
그럼 뭔가 쪼개서 구해야 할 듯 싶다는 생각이 든다.
[9, 14]의 합은 Array[9] + Array[5] + Array[6] +Array[14]를 구해보면 된다.
Array[5]는 Array[10] + Array[11]이고, Array[6]은 Array[12] + Array[13]이니까.
문제는 구현인데, 코드를 살펴보기 앞서 답을 구해가는 과정에 대해 먼저 알아보자.
-> 시작은 루트인 Array[1]이다. Array[1]은 [8, 15]의 값을 모두 담고 있으니, [2]와 [3]으로 쪼개자.
-> Array[2] = Array[8, 11]이다. 결과에 필요없는 부분이 있으니 [4]와 [5]로 쪼개자.
-> Array[4] = Array[8, 9]이므로 필요없는 부분이 있으니 [8]과 [9]로 쪼개자.
-> Array[8]은 필요없다. 버리자.
-> Array[9]는 [9, 14]에 포함된다. 결과에 더하자.
-> Array[5] = Array[10, 11]이므로 [9, 14]에 포함된다. 결과에 더하자.
-> Array[3] = Array[12, 15]이다. 결과에 필요없는 부분이 있으니 [6]과 [7]로 쪼개자.
-> Array[6] = Array[12, 13]이므로 [9, 14]에 포함된다. 결과에 더하자.
-> Array[7] = Array[14, 15]이므로 필요없는 부분이 있으니 [14]와 [15]로 쪼개자.
-> Array[14]는 [9, 14]에 포함된다. 결과에 더하자.
-> Array[15]는 필요없다. 버리자.
.... 쓰고보니 뭔가 알아보기 매우 힘들다..
들여쓰기가 늘어남에 따라 자식으로 한번 더 내려간다고 생각하면 될 것 같다. 답을 찾아가는 과정을 딱 보아하니, 재귀적으로 구현하면 될 것 같다는 생각이 든다.
핵심적인 부분은,
구하고자 하는 구간 [left, right]에 지금 도착한 노드의 담당 구간 [start, end]가 포함되는지 아닌지를 체크하면서 자식으로 갈지 멈출지를 결정한다는 것이다.
즉, left <= start 이고 end <= right 이면 그 구간은 모두 필요하므로 답에 더해주고,
left > end 이거나 right < start 이면 [start, end]는 모두 필요 없으므로 더이상 진행하지 않는다.
모두 필요없는 구간은 밑 그림을 참고해보자.
그 외에는 걸쳐있는 부분이니 한번 더 자식으로 들어간다.
걸쳐있는 부분이란,
위 그림과 같이 현재 노드의 담당 구간 [Start, End]에 대하여 필요한 구간 [Left, Right]가 걸쳐서 형성되어 있는 경우를 말한다.
사실 위에서 설명했던 답에 포함되는 구간(모두 필요한 구간)과 모두 필요없는 구간(left > end 이거나 right < start)이 아닌 나머지는 전부 걸쳐있다고 생각하면 되겠다.
이제 구현만 남았다.
int Find(int start, int end, int idx, int left, int right)
{
if (left > end || right < start) // 필요없는 구간
return 0;
else if (left <= start && end <= right) // 모두 필요한 구간
return Array[idx];
else // 그 외 : 걸쳐져 있는 구간
return Find(start, (start + end) / 2, idx * 2, left, right)
+ Find((start + end) / 2 + 1, idx * 2 + 1, left, right);
}
걸쳐져 있는 구간(else 문)에서 start, end에 대한 계산은 위쪽에 있는 트리 그림을 참고하며 보면 이해하기 쉬울 것이다.
구간을 반으로 쪼개며 자식으로 들어가니까 [start, end]는 [start, (start + end) / 2]와 [(start + end) / 2 + 1, end] 두 개로 나뉘어 진다고 생각하면 된다.
idx는 현재 트리의 노드의 index이고, [left, right]는 답을 구하고자 하는 구간이다.
마지막으로 시간복잡도는 함수가 호출되는 시점마다 이진트리의 한 level을 건너가므로, O(logN)이라 할 수 있겠다.
좀더 정확하게 말하자면 N=2^k인 경우가 꽉찬 이진트리가 생성이 되고, 매 호출마다 분할되어 다음 자식으로 내려가기 때문에 마치 병합 정렬(Merge Sort)의 분할과 같은 메커니즘으로 동작한다고 볼 수 있겠다.
결국 Update, Find 모두 O(logN)만에 수행이 되므로 .. 쿼리 M개에 대해서 O(MlogN)만에 해결할 수 있는 방법을 알게 된 것이다!
이제 9345번 문제만 남았다..
3. [백준/BOJ 9345번] 디지털 비디오 디스크(DVDs)
이렇게 우선 수들의 합(2268번) 문제를 기준으로 세그먼트 트리에 대한 설명을 해 보았다.
벌써 이렇게나 길어졌다니...
마지막으로 9345번 디지털 비디오 디스크(DVDs) 문제에 대한 설명을 시작하겠다.
문제가 매우 길다.. 간단한 조건만 간추려 보자면,
L번 선반부터 R번 선반까지에 있는 DVD들을 가져 왔을때 실제로 DVD가 L번부터 R번까지 있나 확인을 해주어야 하며,
[L, R] 구간에는 순서에 상관없이 L번부터 R번까지 DVD만 있으면 된다.
그러니까, [2, 5]의 구간(선반)에 있는 DVD를 빌릴려고 하는데, 실제로 {2, 3, 4, 5}가 순서에 상관없이 Array[2, 5]에만 들어가 있으면 된다는 말 같다.
한 가지 재미있는 사실은, 원소들의 순서가 {2, 3, 4, 5}든 {4, 5, 3, 2}든 {3, 5, 2, 4}든 구간 [2, 5]의 최대값과 최소값은 항상 2와 5로 고정되어있다는 점이다.
오! 그렇다면 DVD가 섞여도 구간의 최대, 최소만 계속 기억해둔다면 구간에 필요한 원소들이 모두 들어가있는지 확인할 수 있을 것이다.
또한 테스트케이스 T가 20이하이고 조건이 N이 10만까지이며, 일어나는 사건의 수 K가 5만까지이므로 O(TNK) 안에는 해결할 수 없을 것이다.
따라서 최대, 최소를 저장하는 세그먼트 트리를 각각 만들어서 문제를 해결하면 O(TKlogN)으로 해결할 수 있을 것이다.
N이 10만이니까, 배열 크기는 (262144 - 1)는 되어야 하지만 역시 귀찮으므로..
int minTree[300000], maxTree[300000];
int N, K, T;
다음 트리의 초기화를 진행해보자.
int initTree(int n)
{
// 매 Testcase마다 배열을 사용하므로 초기화
memset(minTree, 0, sizeof(minTree));
memset(maxTree, 0, sizeof(maxTree));
int bound = 1;
while (bound < n)
bound *= 2;
for (int i = bound; i < bound + N; i++)
maxTree[i] = minTree[i] = i - bound;
for (int i = bound + N; i < 2 * bound; i++)
minTree[i] = N; // 최소값 트리는 영향을 주지 않는 수(>= N)로 초기화
for (int i = bound - 1; i > 0; i--) // 최대, 최소 트리 초기화
{
maxTree[i] = max(maxTree[2 * i], maxTree[2 * i + 1]);
minTree[i] = min(minTree[2 * i], minTree[2 * i + 1]);
}
return bound; // 리프 노드의 시작인 bound를 반환
}
이제 코드 이해는 쉬울 것이다. bound를 반환하는 이유는 그냥 main 함수에 초기화 부분을 쓰면 길어지니까 따로 함수로 빼두고, bound는 써야하니까 반환했다.. 내 취향이니 bound는 전역변수로 써도 무방하고 어떻게 써도 상관없다.
다음은 Update함수를 보자.
void updateTree(int a, int b)
{
// a, b에 있는 DVD 위치를 서로 바꿈
swap(minTree[a], minTree[b]);
swap(maxTree[a], maxTree[b]);
int idx = a / 2;
while (idx > 0) // Tree[a]부터 출발하여 부모들의 값을 갱신
{
minTree[idx] = min(minTree[idx * 2], minTree[idx * 2 + 1]);
maxTree[idx] = max(maxTree[idx * 2], maxTree[idx * 2 + 1]);
idx /= 2;
}
idx = b / 2;
while (idx > 0) // Tree[b]부터 출발하여 부모들의 값을 갱신
{
minTree[idx] = min(minTree[idx * 2], minTree[idx * 2 + 1]);
maxTree[idx] = max(maxTree[idx * 2], maxTree[idx * 2 + 1]);
idx /= 2;
}
}
역시 어려운 부분이 없다.
두번 갱신하는 이유는 DVD를 서로 맞교환하므로 Tree[a]와 Tree[b]를 swap한 후, 각 두 개의 위치에서 갱신을 수행해주면 된다.
다음은 Find!
int findMax(int start, int end, int idx, int left, int right)
{
if (left > end || right < start)
return -1; // 최대값이므로 0보다 작은 수 반환
else if (start >= left && end <= right)
return maxTree[idx];
else
return max(findMax(start, (start + end) / 2, idx * 2, left, right),
findMax((start + end) / 2 + 1, end, idx * 2 + 1, left, right));
}
int findMin(int start, int end, int idx, int left, int right)
{
if (end < left || start > right)
return N; // 최소값이므로 답에 영향을 주지않는 값(>=N)으로 반환
else if (start >= left && end <= right)
return minTree[idx];
else
return min(findMin(start, (start + end) / 2, idx * 2, left, right),
findMin((start + end) / 2 + 1, end, idx * 2 + 1, left, right));
}
역시 수들의 합 문제를 설명할 때 보였던 부분과 달라진 바가 없다. 최대, 최소에 따라 반환하는 값을 잘 조정해주면 된다.
마지막으로, [left, right]의 구간의 최소, 최대가 각각 left, right와 같은지만 확인하면 된다.
bool isPossible(int left, int right, int bound)
{
int Max = findMax(bound, 2 * bound - 1, 1, left, right);
int Min = findMin(bound, 2 * bound - 1, 1, left, right);
if (left - bound == Min && right - bound == Max)
return true;
return false;
}
find함수의 start, end 호출 인자로는 Tree[idx]가 리프노드 전체를 담당하므로 리프노드의 시작인 bound부터 2 * bound - 1까지의 범위를 대입해주면 된다.
(+ 참고)
이 문제처럼 입출력이 굉장히 잦은 (사실 기준은 잘 모르겠다..) 문제들은 간혹 cin/cout으로 표준입출력을 수행하면 제출 시 시간 초과가 나는 경우가 있다.
이 문제도 그러하다. 따라서 입출력이 잦은 경우 가급적 <stdio.h>의 표준입출력을 사용하는 것이 좋을 듯 하다.
<전체 소스>
#include <stdio.h>
#include <iostream>
#include <algorithm>
#include <memory.h>
using namespace std;
int minTree[300000], maxTree[300000];
int N, K, T;
int initTree(int n)
{
// 매 Testcase마다 배열을 사용하므로 초기화
memset(minTree, 0, sizeof(minTree));
memset(maxTree, 0, sizeof(maxTree));
int bound = 1;
while (bound < n)
bound *= 2;
for (int i = bound; i < bound + N; i++)
maxTree[i] = minTree[i] = i - bound;
for (int i = bound + N; i < 2 * bound; i++)
minTree[i] = N; // 최소값 트리는 영향을 주지 않는 수(>= N)로 초기화
for (int i = bound - 1; i > 0; i--) // 최대, 최소 트리 초기화
{
maxTree[i] = max(maxTree[2 * i], maxTree[2 * i + 1]);
minTree[i] = min(minTree[2 * i], minTree[2 * i + 1]);
}
return bound; // 리프 노드의 시작인 bound를 반환
}
void updateTree(int a, int b)
{
// a, b에 있는 DVD 위치를 서로 바꿈
swap(minTree[a], minTree[b]);
swap(maxTree[a], maxTree[b]);
int idx = a / 2;
while (idx > 0) // Tree[a]부터 출발하여 부모들의 값을 갱신
{
minTree[idx] = min(minTree[idx * 2], minTree[idx * 2 + 1]);
maxTree[idx] = max(maxTree[idx * 2], maxTree[idx * 2 + 1]);
idx /= 2;
}
idx = b / 2;
while (idx > 0) // Tree[b]부터 출발하여 부모들의 값을 갱신
{
minTree[idx] = min(minTree[idx * 2], minTree[idx * 2 + 1]);
maxTree[idx] = max(maxTree[idx * 2], maxTree[idx * 2 + 1]);
idx /= 2;
}
}
int findMax(int start, int end, int idx, int left, int right)
{
if (left > end || right < start)
return -1; // 최대값이므로 0보다 작은 수 반환
else if (start >= left && end <= right)
return maxTree[idx];
else
return max(findMax(start, (start + end) / 2, idx * 2, left, right),
findMax((start + end) / 2 + 1, end, idx * 2 + 1, left, right));
}
int findMin(int start, int end, int idx, int left, int right)
{
if (end < left || start > right)
return N; // 최소값이므로 답에 영향을 주지않는 값(>=N)으로 반환
else if (start >= left && end <= right)
return minTree[idx];
else
return min(findMin(start, (start + end) / 2, idx * 2, left, right),
findMin((start + end) / 2 + 1, end, idx * 2 + 1, left, right));
}
bool isPossible(int left, int right, int bound)
{
int Max = findMax(bound, 2 * bound - 1, 1, left, right);
int Min = findMin(bound, 2 * bound - 1, 1, left, right);
if (left - bound == Min && right - bound == Max)
return true;
return false;
}
int main(void)
{
scanf("%d", &T);
for (int t = 0; t < T; t++)
{
scanf("%d %d", &N, &K);
int bound = initTree(N);
for (int k = 0; k < K; k++)
{
int Q, A, B;
scanf("%d %d %d", &Q, &A, &B);
if (Q == 0)
updateTree(bound + A, bound + B);
else
{
if (isPossible(bound + A, bound + B, bound))
printf("YES\n");
else
printf("NO\n");
}
}
}
}