k-d Tree Algorithm
Introduction
Building a k-d tree reduces nearest-neighbour search time significantly when working with low-dimensional data. However, due to the curse of dimensionality, it performs no better than exhaustive search on high-dimensional data: in that case most of the nodes will be visited anyway, and building the k-d tree itself takes time. Hence, approximate nearest-neighbour methods are proposed for this situation.
Tree node
# tree node structure
class Node:
    """A single node of the k-d tree.

    Attributes:
        left: subtree built from the points sorted before this node's point
            along the splitting dimension (None when empty).
        right: subtree built from the points sorted after this node's point
            along the splitting dimension (None when empty).
        depth: index of the dimension this node splits on. buildKdTree
            stores the recursion depth modulo the number of dimensions here,
            and searchKdTree uses it to index the query point.
        point: the data point stored at this node.
        val: coordinate of ``point`` along the splitting dimension.
        distance: distance from ``point`` to the query, written by
            searchKdTree; -1 marks a node that has not been visited yet.
    """

    # Defaults let callers create an empty node with ``Node()`` instead of
    # spelling out every placeholder; positional calls keep working.
    def __init__(self, left=None, right=None, depth=None, point=None,
                 val=None, distance=-1):
        self.left = left
        self.right = right
        self.depth = depth
        self.point = point
        self.val = val
        self.distance = distance
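“val” is the median value of the node's points along its splitting dimension, i.e. the coordinate of the node's own point in that dimension.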
Build k-d tree
# build k-d tree
def buildKdTree(data, D=0):
    """Recursively build a k-d tree from a 2-D array of points.

    Args:
        data: array of points, one row per point, indexable as
            ``data[:, d]``.
        D: recursion depth of the node being built; the splitting dimension
            is ``D`` modulo the module-level ``nmb_dimension``.

    Returns:
        The root Node of the (sub)tree, or None when ``data`` is empty.
    """
    # an empty slice means there is no subtree on this side
    if len(data) == 0:
        return None

    # cycle through the dimensions as the recursion goes deeper
    axis = D % nmb_dimension
    # order the points along the splitting dimension
    ordered = data[data[:, axis].argsort()]
    # the middle row becomes this node; its coordinate is the split value
    mid = ordered.shape[0] // 2
    node = Node(None, None, axis, ordered[mid], ordered[mid, axis], -1)
    # rows before the median go left, rows after it go right
    node.left = buildKdTree(ordered[:mid], D + 1)
    node.right = buildKdTree(ordered[mid + 1:], D + 1)
    return node
Search tree
#calculate the Euclidean distance between two points
def distance(point1, point2):
    """Return the Euclidean distance between two points.

    Args:
        point1: sequence of coordinates (list, tuple or 1-D array).
        point2: sequence of coordinates with the same length as ``point1``.

    Returns:
        The distance as a Python float.

    Raises:
        ValueError: if the two points do not have the same length.
    """
    import math

    if len(point1) != len(point2):
        raise ValueError("points must have the same number of dimensions")
    # Convert each coordinate to float before subtracting: with unsigned or
    # narrow integer arrays the difference would wrap around and the square
    # could overflow, giving a wrong distance.
    return math.sqrt(sum((float(a) - float(b)) ** 2
                         for a, b in zip(point1, point2)))
# K-d tree search
def searchKdTree(k,here,test_sample,best):
    """Collect the k nodes nearest to ``test_sample`` in the subtree at ``here``.

    Args:
        k: number of nearest neighbours to keep.
        here: root Node of the subtree to search, or None.
        test_sample: query point, indexable by dimension.
        best: list of candidate nodes found so far; it is updated in place.
            Pass an empty list for a fresh query.

    Returns:
        ``best``, holding at most ``k`` nodes in no particular order. Each
        node's ``distance`` attribute is its distance to ``test_sample``.
    """
    if here is None or k <= 0:
        return best

    # stored on the node so callers can read it from the returned nodes
    here.distance = distance(here.point, test_sample)

    if len(best) < k:
        # keep every node until k candidates are held, however far it is;
        # otherwise the list could never fill up to k
        best.append(here)
    else:
        # the list is full: replace the farthest candidate if this is closer
        farthest = max(range(len(best)), key=lambda i: best[i].distance)
        if here.distance < best[farthest].distance:
            best[farthest] = here

    # descend first into the side of the splitting plane holding the query
    if test_sample[here.depth] < here.val:
        child_near = here.left
        child_far = here.right
    else:
        child_near = here.right
        child_far = here.left
    best = searchKdTree(k, child_near, test_sample, best)

    # Distance from the query to the splitting plane is a lower bound for
    # every point on the far side. Computed in floats so unsigned integer
    # coordinates cannot wrap around.
    plane_gap = abs(float(test_sample[here.depth]) - float(here.val))
    # The far side can only be skipped once k candidates are held and none
    # of its points could beat the current farthest candidate.
    if len(best) < k or plane_gap < max(node.distance for node in best):
        best = searchKdTree(k, child_far, test_sample, best)
    return best