k-d Tree Algorithm

Introduction

Building a k-d tree can significantly reduce nearest-neighbour search time on low-dimensional data. However, because of the curse of dimensionality, it performs no better than exhaustive search on high-dimensional data: most of the nodes end up being visited, and building the tree itself also takes time. For this reason, approximate nearest-neighbour methods have been proposed for high-dimensional data.

Tree node

# tree node structure
class Node:
    """One node of the k-d tree.

    Attributes:
        left, right: child subtrees (``Node`` or ``None``).
        depth: despite its name, buildKdTree stores the splitting dimension
            (``D % nmb_dimension``) here, not the depth in the tree.
        point: the sample stored at this node.
        val: the coordinate of ``point`` along the splitting dimension.
        distance: scratch value; buildKdTree initialises it to -1 and
            searchKdTree overwrites it with the distance to the current query.
    """

    def __init__(self, left, right, depth, point, val, distance):
        # children
        self.left, self.right = left, right
        # splitting dimension (see class docstring)
        self.depth = depth
        # stored sample and its coordinate on the splitting dimension
        self.point, self.val = point, val
        # per-query distance cache used during search
        self.distance = distance

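`val` is the median value along the node's splitting dimension: after the points are sorted on that dimension, it is the coordinate of the point at index n // 2.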

Build k-d tree

# build k-d tree
def buildKdTree(data, D=0):
    """Recursively build a k-d tree from the rows of ``data``.

    At recursion level ``D`` the rows are sorted on dimension
    ``D % nmb_dimension`` (``nmb_dimension`` is a global defined elsewhere
    in the module). The row at index ``len // 2`` becomes this node. Rows
    before it form the left subtree and rows after it form the right subtree.

    Args:
        data: 2-D array with one point per row (indexed as ``data[:, d]``),
            presumably a NumPy array.
        D: current recursion level; 0 at the root.

    Returns:
        The subtree's root ``Node``, or ``None`` when ``data`` has no rows.
    """
    # an empty partition produces no subtree
    if len(data) == 0:
        return None

    axis = D % nmb_dimension
    # rows reordered by their coordinate on the splitting axis
    ordered = data[data[:, axis].argsort()]
    mid = ordered.shape[0] // 2

    # the median row lives at this node; ``depth`` records the split axis
    node = Node(None, None, axis, ordered[mid], ordered[mid, axis], -1)
    node.left = buildKdTree(ordered[:mid], D + 1)
    node.right = buildKdTree(ordered[mid + 1:], D + 1)
    return node

Search tree

# Euclidean distance helper
def distance(point1, point2):
    """Return the Euclidean (L2) distance between ``point1`` and ``point2``.

    Both arguments must support element-wise subtraction and multiplication
    (e.g. equal-length NumPy arrays); plain Python lists are not supported.
    """
    delta = point1 - point2
    return sum(delta * delta) ** 0.5

# K-d tree search
def searchKdTree(k, here, test_sample, best):
    """Collect the ``k`` nearest neighbours of ``test_sample`` in a k-d (sub)tree.

    Visits ``here``, then descends first into the child on the same side of
    the splitting plane as ``test_sample``. Afterwards it descends into the
    other child unless that child provably cannot hold a closer point.

    Args:
        k: number of neighbours wanted.
        here: root of the subtree to search (a ``Node``), or ``None``.
        test_sample: query point; indexed by the node's split dimension and
            passed to ``distance``.
        best: candidate nodes found so far. The list is modified in place
            and also returned. Pass ``[]`` for a fresh search.

    Returns:
        ``best``. Starting from ``[]`` with ``k >= 1`` it ends up holding the
        ``min(k, number of nodes)`` nodes closest to ``test_sample``. The
        nodes are in no particular order.

    Side effects:
        Every visited node's ``distance`` attribute is overwritten with its
        distance to ``test_sample``. The ``distance`` values of the returned
        nodes therefore refer to the current query only.
    """
    if here is None:
        return best

    here.distance = distance(here.point, test_sample)

    if len(best) < k or not best:
        # Until k candidates are known, every visited node is a candidate.
        # An empty list always receives the first node, so the search radius
        # used below is defined even when k <= 0.
        best.append(here)
    else:
        # Replace the farthest candidate (the first one on ties) if ``here``
        # is strictly closer to the query.
        farthest = max(range(len(best)), key=lambda j: best[j].distance)
        if here.distance < best[farthest].distance:
            best.pop(farthest)
            best.append(here)

    axis = here.depth  # buildKdTree stores the split dimension in ``depth``
    if test_sample[axis] < here.val:
        child_near, child_far = here.left, here.right
    else:
        child_near, child_far = here.right, here.left

    best = searchKdTree(k, child_near, test_sample, best)

    # Every point on the far side is at least |test_sample[axis] - val| away
    # from the query. The far side can therefore only contain a closer point
    # if that gap is below the current search radius (the distance to the
    # farthest candidate). While fewer than k candidates are known, the far
    # side must always be searched.
    if len(best) < k or abs(test_sample[axis] - here.val) < max(
        node.distance for node in best
    ):
        best = searchKdTree(k, child_far, test_sample, best)
    return best