k-d트리는 다차원의 점들(자료들)을 저장하기 위한 자료구조입니다. 알아보기 전에 먼저 저보다 훨씬 잘 설명해둔 위키피디아 페이지를 소개하겠습니다. (링크)

그냥 여기만으로 이해하는 게 훨씬 편할 수도 있습니다.

개요

설명하자면 k-d 트리는 모든 노드가 k차원 점인 이진 트리입니다. 모든 리프 노드는 공간을 반평면의 두 부분으로 나누는 분할 초평면을 만드는 것으로 생각할 수 있습니다. 이 초평면의 왼쪽은 그 노드의 왼쪽 부분 트리를 나타내고 오른쪽은 오른쪽 부분 트리를 나타냅니다. 초평면은 k차원 중 하나의 차원축에 수직적이어야 합니다. 그러니 예를 들어 x=3을 기준으로 x<3인 점을 왼쪽 서브트리에, x>3인 점을 오른쪽 서브트리에 둘 수 있겠죠.

작동 방식

그럼 이제부터 k-d 트리에서 어떻게 점을 저장하는지 자세히 알아보도록 합시다.

k차원의 점들이 N개 존재한다 합시다. 먼저 우리는 N개의 점들을 X좌표를 기준으로 나누어주어야 합니다. 적절하게 나누기 위해서는 그 점들의 x좌표의 중앙값을 기준으로 나누는 게 제일 적절하겠죠. 중앙값은 그냥 점들 전체를 정렬하면 O(NlogN)에 구할 수 있습니다. 더 빠르게도 구할 수 있다고 하는데, 그건 좀 논점 밖이죠.

그리고 그 노드에는 그 x좌표의 중앙값에 해당하는 점을 배치하고 그 점보다 x좌표가 작은 점들은 왼쪽, 그 점보다 x좌표가 큰 점들은 오른쪽으로 갑니다.

그리고 이제 다시 점들을 나눠줘야 되는데, 또 x좌표로 정렬하는 걸까요? 아닙니다. 그러면 x좌표가 모두 같을 때는 트리가 너무 비효율적인 구조로 만들어집니다. 대신 우리는 y좌표를 기준으로 하여 점들을 정렬해줍니다.

왼쪽에서도 y좌표를 기준으로 정렬해서 중앙값 기준으로 나눠주고, 오른쪽에서도 y좌표 기준으로 정렬해서 중앙값 기준으로 나눠줍니다.

그리고 또 점들은 더 밑으로 내려오겠죠. 그럼 그 점들은 이제 z좌표를 기준으로 나눠주면 됩니다! 물론 3차원 이상의 점이 아니라면 다시 x좌표 기준으로 나눠줍니다.

정리하자면, 점들의 차원이 k차원이고 현재 depth가 x일 때, (x%k) 번째 차원을 기준으로 점들을 나눠줍니다. (깊이와 차원은 모두 0-based)

예시

2차원에서 k-d tree를 나타낸 그림을 보여드리겠습니다.

2차원에서의 k-d tree 저 그림에서 어떻게 k-d tree가 구축되는지 알아봅시다. 먼저 모든 점들을 봅시다. (2,3),(4,7),(5,4),(8,1),(9,6),(7,2)가 있습니다. x좌표 순으로 정렬하면 (2,3),(4,7),(5,4),(7,2),(8,1),(9,6)입니다. 자료가 짝수 개인데 중앙값은 평균으로 뽑아야 될까요? 아니요. 그냥 3번째나 4번째 중에 맘에 드는 거 뽑습니다. 저는 4번째인 (7,2)를 뽑겠습니다.

그리고 깊이를 한 단계 내려갑시다. 왼쪽으로는 점 (2,3),(4,7),(5,4)이 내려옵니다. y좌표 기준으로 정렬했을 때 중간은 (5,4)이고 y좌표가 작은 (2,3)이 왼쪽, y좌표가 큰 (4,7)이 오른쪽으로 갑니다.

오른쪽에는 (8,1),(9,6)이 오는데 점이 두 개밖에 없으니 아무 거나 하나는 위에 두고 하나는 밑에 두면 됩니다. (9,6)을 위에 두고 y좌표가 작은 (8,1)은 왼쪽에 두면 됩니다.

시간 복잡도와 전체 코드

그럼 k-d tree를 만드는 데 걸리는 시간복잡도는 어떻게 될까요? 각 노드에서 중앙값을 구하는 데에는 그 노드의 서브트리에 속한 점들의 수를 N이라 할 때 O(NlogN)의 시간이 걸립니다. 한 높이에 N개의 점이 존재하고 높이는 최대 O(logN)이므로 한 높이에서 O(NlogN)*최대 높이 O(logN)을 곱하면 O(Nlog^2N)입니다.

2차원에서의 k-d트리를 구현한 코드입니다. 세그먼트 트리나 힙에서도 쓰이는 아이디어를 이용해서 루트 노드의 번호를 1로 하고, 한 노드의 번호를 k라 했을 때 2k가 왼쪽 자식의 번호, 2k+1이 오른쪽 자식의 번호가 되도록 하였습니다.


#include <bits/stdc++.h>
using namespace std;

typedef pair<long long,long long> P;

struct Node {
	P pt;
	bool st; //standard. false면 x좌표 비교, true이면 y좌표 비교
	long long sx,ex,sy,ey; //x좌표 시작,끝. y좌표 시작,끝.
};

Node kdtree[513040];
int n;
bool isexist[513040]; //해당 인덱스의 노드가 kdtree에 존재하는가
P point[142857];
long long ret; //현재까지 찾은 최소거리. 재귀적으로 들어가면서도 잘 작동해야 하니 전역변수로 둠.
P save[142857];

long long square(long long a) {
	return a*a;
}

bool xcomp(P a,P b) { //x좌표 기준 비교
	if (a.first==b.first)
	return a.second<b.second;
	return a.first<b.first;
}

bool ycomp(P a,P b) { //y좌표 기준 비교
	if (a.second==b.second)
	return a.first<b.first;
	return a.second<b.second;
}

void maketree(int l,int r,int ind) {
	long long minx=2e9;
	long long maxx=-2e9;
	long long miny=2e9;
	long long maxy=-2e9;
	for(int i=l;i<=r;i++) {
		minx=min(minx,point[i].first);
		maxx=max(maxx,point[i].first);
		miny=min(miny,point[i].second);
		maxy=max(maxy,point[i].second);
	}
	kdtree[ind].sx=minx;
	kdtree[ind].ex=maxx;
	kdtree[ind].sy=miny;
	kdtree[ind].ey=maxy;
	isexist[ind]=true;
	if (ind==1) {
		kdtree[ind].st=false; //루트노드는 x좌표를 비교하게 설정
	}
	else {
		kdtree[ind].st=!kdtree[ind/2].st; //부모노드의 반대로 비교
	}
	if (kdtree[ind].st) {
		sort(point+l,point+r+1,ycomp); //point배열 정렬
	}
	else {
		sort(point+l,point+r+1,xcomp);
	}
	int mid=(l+r)/2;
	kdtree[ind].pt=point[mid];  //중앙의 점을 그 kdtree의 점으로 저장
	if (l<=mid-1)    //재귀적으로 왼쪽 서브트리 만들기
	    maketree(l,mid-1,ind*2);
	if (mid+1<=r)    //재귀적으로 오른쪽 서브트리 만들기
	    maketree(mid+1,r,ind*2+1);
}

활용

그러면 이제 k-d tree를 어떻게 활용하죠? 사실 ps에서 활용하는 방법이 많지는 않습니다.

하지만 CERC 2008에서 딱 이것을 이용할 수 있는 문제가 나왔습니다! (링크)
티어가 많이 높네요… 역시 이상한 자료구조에요.

여러 개의 점들이 주어질 때, 각각의 점에서 가장 가까운 점을 찾을 때 쓸 수 있습니다!
트리를 재귀적으로 내려가면서 그 점이 있는 방향으로는 무조건 내려가고, 그 점이 없는 방향에는 그 점과의 x(y좌표)의 최소 차이가 현재 최솟값보다 작다면 내려가주면 됩니다.

코드는 이렇게 됩니다.


#include <bits/stdc++.h>
using namespace std;

typedef pair<long long,long long> P;

struct Node {
	P pt;
	bool st; //standard. false면 x좌표 비교, true이면 y좌표 비교
	long long sx,ex,sy,ey; //x좌표 시작,끝. y좌표 시작,끝.
};

Node kdtree[513040];
int n;
bool isexist[513040]; //해당 인덱스의 노드가 kdtree에 존재하는가
P point[142857];
long long ret; //현재까지 찾은 최소거리. 재귀적으로 들어가면서도 잘 작동해야 하니 전역변수로 둠.
P save[142857];

long long square(long long a) {
	return a*a;
}

bool xcomp(P a,P b) { //x좌표 기준 비교
	if (a.first==b.first)
	return a.second<b.second;
	return a.first<b.first;
}

bool ycomp(P a,P b) { //y좌표 기준 비교
	if (a.second==b.second)
	return a.first<b.first;
	return a.second<b.second;
}

void maketree(int l,int r,int ind) {
	long long minx=2e9;
	long long maxx=-2e9;
	long long miny=2e9;
	long long maxy=-2e9;
	for(int i=l;i<=r;i++) {
		minx=min(minx,point[i].first);
		maxx=max(maxx,point[i].first);
		miny=min(miny,point[i].second);
		maxy=max(maxy,point[i].second);
	}
	kdtree[ind].sx=minx;
	kdtree[ind].ex=maxx;
	kdtree[ind].sy=miny;
	kdtree[ind].ey=maxy;
	isexist[ind]=true;
	if (ind==1) {
		kdtree[ind].st=false; //루트노드는 x좌표를 비교하게 설정
	}
	else {
		kdtree[ind].st=!kdtree[ind/2].st; //부모노드의 반대로 비교
	}
	if (kdtree[ind].st) {
		sort(point+l,point+r+1,ycomp); //point배열 정렬
	}
	else {
		sort(point+l,point+r+1,xcomp);
	}
	int mid=(l+r)/2;
	kdtree[ind].pt=point[mid];  //중앙의 점을 그 kdtree의 점으로 저장
	if (l<=mid-1)    //재귀적으로 왼쪽 서브트리 만들기
	    maketree(l,mid-1,ind*2);
	if (mid+1<=r)    //재귀적으로 오른쪽 서브트리 만들기
	    maketree(mid+1,r,ind*2+1);
}

void getans(int ind,P now) {
	if (now!=kdtree[ind].pt) {
		ret=min(ret,square(kdtree[ind].pt.first-now.first)+square(kdtree[ind].pt.second-now.second));
	}
	// 일단 알아둬야 하는 게 존재하지 않는 노드로는 들어가서는 안 됨
	if (kdtree[ind].st) { //y좌표 기준 비교하는 깊이일 때
		if (now.second<kdtree[ind].pt.second) {
			if (isexist[ind*2]) { //자신의 y좌표가 그거보다 작으니 작은 쪽으로는 바로 탐색
				getans(ind*2,now);
			}
			if (isexist[ind*2+1]&&square(kdtree[ind*2+1].sy-now.second)<ret) {//최소한의 y좌표 차이가 현재까지의 최솟값보다 작을 때 탐색
                getans(ind*2+1,now);
			}
		}
		else { //그냥 반대
			if (isexist[ind*2+1]) {
				getans(ind*2+1,now);
			}
			if (isexist[ind*2]&&square(kdtree[ind*2].ey-now.second)<ret) {
				getans(ind*2,now);
			}
		}
	}
	else { //x좌표 기준으로 바꿔서 처리
		if (now.first<kdtree[ind].pt.first) {
			if (isexist[ind*2]) {
				getans(ind*2,now);
			}
			if (isexist[ind*2+1]&&square(kdtree[ind*2+1].sx-now.first)<ret) {
				getans(ind*2+1,now);
			}
		}
		else {
			if (isexist[ind*2+1]) {
				getans(ind*2+1,now);
			}
			if (isexist[ind*2]&&square(kdtree[ind*2].ex-now.first)<ret) {
				getans(ind*2,now);
			}
		}
	}
}

int main(void) {
	int tc;
	scanf("%d\n",&tc);
	for(int t=0;t<tc;t++) {
		memset(isexist,0,sizeof(isexist)); //배열 초기화
		int n;
		scanf("%d\n",&n);
        vector<P> temp;
		for(int i=0;i<n;i++) {
			int x,y;
			scanf("%d %d\n",&x,&y);
			temp.push_back({x,y});
			save[i]=temp[i];
		}
		sort(temp.begin(),temp.end(),xcomp);
        vector<P> two;
		point[0]=temp[0];
		int cnt=1;
		for(int i=1;i<temp.size();i++) {
		    if (temp[i-1]!=temp[i]) {
		        point[cnt++]=temp[i];
		    }
            else {
                two.push_back(temp[i]);
            }
		}
		maketree(0,cnt-1,1);
		for(int i=0;i<n;i++) {
            if (!two.empty()&&*lower_bound(two.begin(),two.end(),save[i])==save[i]) {
                printf("0\n");
            }
            else {
			    ret=5e18; //최악의 값 그 이상으로 초기화
			    getans(1,save[i]);
			    printf("%lld\n",ret);
            }
		}
	}
}

's profile image

2020-02-09 03:00

Read more posts by this author