
Using KD-Tree For Nearest Neighbor Search

This post is branched from my earlier posts on designing a question-question similarity system. In the first of those posts, I discussed the importance of the speed of retrieval of the most similar questions from the training data, given a question asked by a user in an online system. We designed a few strategies, such as the HashMap based retrieval mechanism. The HashMap based retrieval assumes that at least one word is shared between the asked question and its most similar questions.

In most practical situations, 95% of the time, this will suffice, but there could be 5% of cases where either the question being asked has a completely new vocabulary (more probable) or none of the words in the asked question appear in its most similar questions (less probable).

Note : "similar" doesn't mean that a pair of questions has to have at least one word in common.

For example, consider the pair of questions :

"What is the minimum age to buy a gun ?" and "How old do I need to be to obtain a weapon ?".

Ignoring the stop-words, none of the other words are the same in both questions.

In this post, we look at an alternative to linear scanning (note that we do not scan all questions, only the questions represented by the cluster heads). It's quite obvious that if there are N questions of feature size D, then a linear scan over them takes O(N*D) time and space.
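
For reference, a brute-force linear scan is just a distance computation of the query against every row. A minimal sketch (the function name here is illustrative, not part of the actual system) :

import numpy as np

def linear_scan_nearest(vectors, query):
    # squared euclidean distance of the query against every row : O(N*D) work
    distances = np.sum((vectors - query) ** 2, axis=1)
    return np.argmin(distances)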

The idea of a KD-Tree is to construct a tree out of the questions in such a way that it reduces the time complexity of searching for an exact match or the nearest matching questions.

For example, given N points in 1-D space, if the points are sorted, then searching for a point (whether it exists or not) takes O(logN) time using plain binary search. Searching for the nearest point to a reference point is also O(logN), as we show below.

A Python implementation to find the nearest element given a reference point (assuming the points are sorted in 1-D space) :

def get_nearest(arr, ref):
    left, right = 0, len(arr) - 1

    # binary search for the first index whose value is greater than ref
    while left <= right:
        mid = int((left + right) / 2)

        if arr[mid] > ref:
            right = mid - 1
        else:
            left = mid + 1

    # ref is greater than every element, so the last element is the nearest
    if left == len(arr):
        return arr[-1]

    # return whichever endpoint of the interval is closer to ref
    if left > 0 and abs(arr[left - 1] - ref) < abs(arr[left] - ref):
        return arr[left - 1]
    else:
        return arr[left]

The code searches for the interval in the array in which our reference point lies, and then returns either the left or the right endpoint of the interval, depending on which one is closer to our reference point.
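
For example (assuming the input array is already sorted) :

get_nearest([2, 8, 15, 21], 10)    # returns 8, since |8 - 10| < |15 - 10|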

The case for 1-D was quite simple. Now consider points in 2-D. The idea is to first divide the points into two parts along the x-axis (points which lie to the left and points which lie to the right of the splitting value on the x-axis). Then split each of the left and right sides further into two parts, but this time along the y-axis. Thus each of the left and right parts is further split into up and down sub-parts. Repeat the process in round-robin fashion until no further splitting is possible.

KD Tree splitting in 2 dimensions (Source : ResearchGate)

The split-point is chosen to be the point having the median value along the split-axis.
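
For intuition, here is a minimal recursive sketch of this splitting scheme (illustrative only; it ignores the duplicate-median handling and the iterative construction discussed later) :

def build_kdtree(points, depth=0):
    # points : a list of D-dimensional tuples
    if len(points) == 0:
        return None

    axis = depth % len(points[0])                    # round-robin choice of split axis
    points = sorted(points, key=lambda p: p[axis])   # sort along the split axis
    median = len(points) // 2                        # the median point becomes the node

    return {
        "point": points[median],
        "left": build_kdtree(points[:median], depth + 1),
        "right": build_kdtree(points[median + 1:], depth + 1),
    }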

Once the tree is constructed, searching for an exact point in the tree takes O(D*log(N)) time complexity. Why ?

If the root is the reference point, then we are done (in O(D) time); else compare with the value along the split axis of the root, i.e. the x-axis. If the value along the x-axis of the reference point is less than or equal to the split value at the root, then search the left sub-tree, else search the right sub-tree.

For example, to search for (60, 80) in the below tree, one would compare 60 with 51 (the x-axis value of the root) and decide that since 60 > 51 we should go to the right sub-tree. Then compare 80 with 70 (the y-axis value of the root of the right sub-tree) and decide that since 80 > 70 we again go right, and so on.

KD-Tree constructed out of 2D points.

Note that by round-robin splitting we are trying to ensure that the tree is as balanced as possible. Thus with N points, the maximum height is O(logN). Since at each level of the tree we examine at most one node, and each point is D (=2) dimensional, the time complexity for exact search is O(D*log(N)). But round-robin splitting may not always be optimal (what if all or most of the points lie along the x-axis ?)

Searching for the nearest point is a bit trickier in more than one dimension.

Let's traverse the tree as we would do in the case of exact search. For example, if our reference point is (50, 2) (shown in green below), then by traversing the tree we would end up at the bounding box with (10, 30) on the left boundary and (25, 40) on the top boundary. Now is it true that either (10, 30) or (25, 40) is the closest point to (50, 2) ? No, because the closest point is (55, 1), which is inside the bounding box to the right. That means we cannot simply traverse the tree in only one of the left or right directions at each step to get to the solution, as in the case of exact search.

Bounding Box

To solve this problem, we need to know whether we must search only one side of a sub-tree or both sides. But how do we know that ?

Since we reached our destination by traversing last through (10, 30), we compute the (euclidean) distance of (50, 2) from (10, 30), which is 48.83. This distance becomes the radius of a sphere centered at our reference point. Now, if this sphere cuts through the plane of the split axis (x=10), then we also need to scan the left side of the plane, because some region on the left side of the plane can come inside this sphere. This is indeed the case, since the perpendicular distance of the reference point from the plane x=10 is 40 < 48.83, implying that the sphere cuts through the x=10 plane.
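
In code, this pruning test is just a comparison of the perpendicular distance to the split plane against the radius of the current best sphere. A hypothetical helper could look like :

def must_search_other_side(ref_value, split_value, radius):
    # the sphere centered at the reference point cuts the split plane
    # iff the perpendicular distance to the plane is within the radius
    return abs(ref_value - split_value) <= radius

must_search_other_side(50, 10, 48.83)    # True, since 40 <= 48.83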

Fortunately (1, 10) is at a distance of 49.65, which is greater than 48.83 and thus our best guess still remains (10, 30).

Once we have looked into all possible branches of a sub-tree, we start to backtrack up the tree. (25, 40) splits on the y-axis (y=40). Since the perpendicular distance from y=40 to (50, 2) is 38 < 48.83 (the radius of the smallest sphere seen so far), the sphere cuts through this plane too and we need to scan the other side of y=40 as well. Repeating the same for the right sub-tree of (25, 40), we find that the point (50, 50) is closer, i.e. distance = 48 < 48.83. Thus our new solution becomes (50, 50).

Once we are done checking the left and right sub-trees of (25, 40), we check the distance of the reference point from the root (25, 40) itself, which is 45.49 < 48. Thus (25, 40) is our current best solution.

Next up, (51, 75) splits on the x-axis with the x=51 plane. The perpendicular distance of (50, 2) from the x=51 plane is 1, which is less than 45.49 (the radius of the smallest sphere seen so far). Thus we need to scan the right sub-tree rooted at (51, 75). Repeating the above process, we eventually find the true closest point (55, 1). Note that by the time we discover our solution (55, 1), we have scanned each and every point in the tree. Thus nearest neighbor search in a KD-Tree has a worst case complexity of O(N*D), similar to a linear scan.

If the reference point had instead been, say, (12, 33), then we need not have gone towards the right sub-tree of (25, 40) or the right sub-tree of (51, 75), because the radius of the sphere from (10, 30) is 3.6, which is less than the perpendicular distances of (12, 33) from y=40 (which is 7) and from x=51 (which is 39).

Here is a small animation taken from Wikipedia explaining how we go about searching the tree :

Nearest Neighbor Search

Time to code up !!!

Instead of going by the traditional approach of using recursive functions to build and query the tree (as described in many articles and tutorials), we will take an iterative approach, because recursion will throw an error when the depth of the tree increases beyond the maximum recursion depth. Although the depth only grows on the order of O(logN), we still did not want to take any chances.

For the iterative solution, we will use a Queue based approach that builds the tree level-wise.

import numpy as np
import time, math, heapq
from collections import deque

"""
Class for defining all variables that go into the queue while
constructing the KD Tree
"""
class QueueObj(object):
    def __init__(self, indices, depth, node, left, right):
        self.indices, self.depth, self.node = indices, depth, node
        self.left, self.right = left, right

"""
Class for defining the node properties for the KD Tree
"""
class Node(object):
    def __init__(self, vector, split_value, split_row_index):
        self.vector, self.split_value, self.split_row_index = vector, split_value, split_row_index
        self.left, self.right = None, None

"""
KD Tree class starts here
"""
class KDTree(object):
    def __init__(self, vectors):
        self.vectors = vectors
        self.root = None
        self.vector_dim = vectors.shape[1]

    def construct(self):
        n = self.vectors.shape[0]

        queue = deque([QueueObj(range(n), 0, None, 0, 0)])

        while len(queue) > 0:
            qob = queue.popleft()
            q_front, depth, parent, l, r = qob.indices, qob.depth, qob.node, qob.left, qob.right

            # the split axis cycles round-robin with the depth of the node
            axis = depth % self.vector_dim

            # sort the rows assigned to this node by their value along the split axis
            vectors = np.argsort(self.vectors[q_front, :][:, axis])
            vectors = [q_front[vec] for vec in vectors]

            m = len(vectors)

            # the median value along the split axis becomes the split value
            median_index = int(m / 2)
            split_value = self.vectors[vectors[median_index]][axis]

            # multiple rows can share the median value along the split axis; binary
            # search for the last index whose value equals the split value so that
            # all rows with value <= split value go to the left sub-tree
            left, right = median_index + 1, m - 1

            while left <= right:
                mid = int((left + right) / 2)

                if self.vectors[vectors[mid]][axis] > split_value:
                    right = mid - 1
                else:
                    left = mid + 1

            median_index = left - 1

            node = Node(self.vectors[vectors[median_index]], split_value, vectors[median_index])

            if parent is None:
                self.root = node

            else:
                if l == 1:
                    parent.left = node
                else:
                    parent.right = node

            # enqueue the rows on each side of the median as the left and right children
            if median_index > 0:
                queueObj = QueueObj(vectors[:median_index], depth + 1, node, 1, 0)
                queue.append(queueObj)

            if median_index < m - 1:
                queueObj = QueueObj(vectors[median_index + 1:], depth + 1, node, 0, 1)
                queue.append(queueObj)

Each node stores the row vector that was used to decide the split, the split value along the axis, the index of the row in the original matrix that corresponds to this vector, and the left and right sub-tree pointers.

Run-time Analysis : If we had written the above code using recursion, then the time complexity would be given by the recurrence :

T(n) = 2T(n/2) + cn*log(n),

because at each level of the recursion we sort the rows passed to the current node by its parent and then split them further into two (roughly) equal parts. The above recurrence relation solves to the following :

T(n) = O(n*log^2(n))

This follows since there are O(log(n)) levels of recursion and the total sorting work across any one level is O(n*log(n)). One can theoretically bring this down to O(n*log(n)) by pre-sorting the matrix and then tracking which rows go where, but the code becomes a bit more complicated. Moreover, the tree construction phase is an offline process, so a factor of log(n) vs. log^2(n) does not really affect the overall latency.

We are choosing the median as the split value. Note that multiple rows can have the same median value along the split axis, and in that case we choose to keep all rows with a value less than or equal to the median in the left sub-tree. This is taken care of by the small piece of code below, which uses binary search to find the last row with the same median value in the sorted array.

left, right = median_index + 1, m - 1

while left <= right:
    mid = int((left + right) / 2)

    if self.vectors[vectors[mid]][axis] > split_value:
        right = mid - 1
    else:
        left = mid + 1

median_index = left - 1

Construct the tree by calling the below function :

tree = KDTree(arr)
tree.construct()

Choosing the split axis in a round-robin manner is not the only way. Another, more optimal, technique is to choose the axis for which the difference between the value just after the median and the value just before the median is greatest, as sketched below. Can you guess why this is a better strategy ?
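
A rough sketch of such an axis selection (a hypothetical replacement for the round-robin line axis = depth % self.vector_dim, where vecs is the sub-matrix of rows reaching the current node) :

import numpy as np

def choose_split_axis(vecs):
    # pick the axis with the largest gap around the median,
    # i.e. the axis along which the points are most spread out near the split
    best_axis, best_gap = 0, -1.0
    m = vecs.shape[0]

    for axis in range(vecs.shape[1]):
        values = np.sort(vecs[:, axis])
        gap = values[min(m - 1, m // 2 + 1)] - values[max(0, m // 2 - 1)]

        if gap > best_gap:
            best_axis, best_gap = axis, gap

    return best_axis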

Following is the code for performing an exact search on the above KD Tree :

def search(self, vector):
    node = self.root

    depth = 0
    while node is not None:
        if np.array_equal(node.vector, vector):
            return True

        # descend left or right depending on the value along the split axis at this depth
        axis = depth % self.vector_dim

        if vector[axis] <= node.split_value:
            node = node.left
        else:
            node = node.right

        depth += 1

    return False

The above function is as simple as doing a search on a binary tree.
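
A quick sanity check (assuming arr is the matrix the tree was constructed from) :

print(tree.search(arr[0]))    # True, since every row of arr becomes a node in the tree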

In our question similarity case, the requirement is to find the nearest K questions given a customer question, or only those questions within a specific radius. The single closest point is a special case of this. So we have written the function keeping in mind the generic case of K nearest neighbors. The below function can be easily modified to handle distance-threshold based nearest neighbors.

We are using a Max Heap data structure to store the smallest K distances discovered so far in the tree and continue updating the heap.

Note that when we have to decide whether we should scan both sub-trees, we use the root of the max heap (the maximum distance among the smallest K distances) as the radius of the sphere that decides whether the split plane cuts through the sphere or not, because there could be some point at a distance less than the root of the max heap but greater than the children of the root.
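
Python's heapq module only provides a min heap, so (as in the code below) distances are pushed negated to simulate a max heap, and the root -distances[0][0] is the largest of the K smallest distances seen so far. A small standalone illustration :

import heapq

k, distances = 2, []

for d, row in [(5.0, 0), (2.0, 1), (9.0, 2), (1.0, 3)]:
    heapq.heappush(distances, (-d, row))    # negate so the heap root holds the largest distance
    if len(distances) > k:
        heapq.heappop(distances)            # evict the current largest distance

print(-distances[0][0])    # 2.0, the largest among the 2 smallest distances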

def insert_distance_into_heap(self, distances, node, node_distance, k):
    # distances is a min heap of negated distances, i.e. a max heap over distances;
    # if the heap is full and the new distance beats the current maximum, evict the maximum
    if len(distances) == k and -distances[0][0] > node_distance:
        heapq.heappop(distances)

    if len(distances) < k:
        heapq.heappush(distances, (-node_distance, node.split_row_index))


def nearest_neighbor(self, vector, k):
    # iterative depth-first traversal : the stack holds (node, depth) pairs and
    # "visited" holds row indices of nodes whose sub-trees are fully processed
    search_stack = [(self.root, 0)]
    distances, visited = [], set()

    while len(search_stack) > 0:
        node, depth = search_stack[-1]

        axis = depth % self.vector_dim
        child_node = None

        if vector[axis] <= node.split_value:
            if node.left is None or node.left.split_row_index in visited:
                node_distance = math.sqrt(np.sum((node.vector - vector) ** 2))

                if node.right is None or node.right.split_row_index in visited:
                    self.insert_distance_into_heap(distances, node, node_distance, k)

                else:
                    # pruning radius : explore unconditionally until k neighbors are found,
                    # then use the current k-th smallest distance (root of the max heap)
                    w = float("inf") if len(distances) < k else -distances[0][0]

                    # scan the right side only if the split plane cuts the sphere of radius w
                    if node.split_value - vector[axis] <= w:
                        child_node = node.right

            else:
                child_node = node.left

        else:
            if node.right is None or node.right.split_row_index in visited:
                node_distance = math.sqrt(np.sum((node.vector - vector) ** 2))

                if node.left is None or node.left.split_row_index in visited:
                    self.insert_distance_into_heap(distances, node, node_distance, k)

                else:
                    # pruning radius : explore unconditionally until k neighbors are found,
                    # then use the current k-th smallest distance (root of the max heap)
                    w = float("inf") if len(distances) < k else -distances[0][0]

                    # scan the left side only if the split plane cuts the sphere of radius w
                    if vector[axis] - node.split_value <= w:
                        child_node = node.left

            else:
                child_node = node.right

        # nothing left to explore below this node : mark it visited and backtrack
        if child_node is None or child_node.split_row_index in visited:
            visited.add(node.split_row_index)
            search_stack.pop()

        else:
            search_stack.append((child_node, depth + 1))

    # convert back from negated heap entries and sort by increasing distance
    distances = [(-x, y) for x, y in distances]
    distances = sorted(distances, key=lambda d: d[0])

    return distances

Again, instead of recursion we are using an iterative method for finding the K nearest neighbors. But instead of a Queue, we are using a Stack data structure, since we are going depth-wise and not level-wise. Also note that we keep a set "visited" to track all those nodes for which we have completed scanning the node itself as well as its left and right sub-trees.

The first level of if..else conditions is the same as in the exact search method, where we go to the left sub-tree if the value along the split axis of the reference point is less than or equal to the split value of the node, else we go right. But then for each node we check whether its children have already been "visited" or whether it is a leaf node. If so, we insert the distance from the node to the reference point into our max heap; otherwise we either go to the other side of the sub-tree (if the sphere cuts the splitting plane) or backtrack upwards.
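
For example, fetching the 3 nearest rows for a query vector (the query here assumes the 2-D points from the earlier figures) :

neighbors = tree.nearest_neighbor(np.array([50, 2]), 3)
print(neighbors)    # list of (distance, row_index) pairs, sorted by increasing distance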

To improve the search speed, instead of using the split plane, one can also use planes along the same axis but passing through the nearest points on each side of the split plane (support planes). To use this method, we need to save, along with the node properties, the nearest points from the split plane on both sides.

nearest planes to the split plane along the x-axis (shown in dotted lines)

class Node(object):
    def __init__(self, vector, split_value, split_row_index, left_nearest, right_nearest):
        self.vector, self.split_value, self.split_row_index = vector, split_value, split_row_index
        self.left_nearest, self.right_nearest = left_nearest, right_nearest
        self.left, self.right = None, None

Add the following in 'construct' method of the KDTree class :

# nearest values to the split plane on the left and on the right (the support planes)
a, b = max(0, int(m / 2) - 1), min(m - 1, median_index + 1)
left_nearest, right_nearest = self.vectors[vectors[a]][axis], self.vectors[vectors[b]][axis]

node = Node(self.vectors[vectors[median_index]], split_value, vectors[median_index], left_nearest, right_nearest)

And modify the 'nearest_neighbor' function as follows :

# For scanning the right sub-tree
if node.right_nearest - vector[axis] <= w:
    child_node = node.right

# For scanning the left sub-tree
if vector[axis] - node.left_nearest <= w:
    child_node = node.left

To modify the above code to handle a threshold-based distance metric, we just have to modify the function "insert_distance_into_heap" to insert into a list instead of a heap, based on some threshold.
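
As a rough sketch (the function name and the fixed radius argument are assumptions, not part of the original code), the heap insert could be replaced with something like the following, with the pruning radius w in nearest_neighbor set to the same fixed threshold :

def insert_distance_within_radius(self, distances, node, node_distance, radius):
    # keep every point whose distance to the reference point is within the radius
    if node_distance <= radius:
        distances.append((node_distance, node.split_row_index))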

Performance Analysis : Although exact search on the KD Tree is very fast compared to linear search, the nearest neighbor search performance is rather poor. In several experiments, the run time of nearest neighbor search was on par with the linear scan method. From a visual analysis perspective, only if the points in the KD Tree are well spaced out in the D-dimensional hyperspace do we gain some advantage over linear scan; in most other cases we end up searching nearly O(N) nodes. In a way it is similar to clustering : clustering is good only when the intra-cluster distances are much smaller than the inter-cluster distances.

Categories: MACHINE LEARNING, PROBLEM SOLVING
