
Fast Nearest Neighbour Search - KD Trees

In a couple of my earlier posts, "designing Q&A retrieval system" and "designing e-commerce system for similar products using deep learning", one of the primary goals was to retrieve the most similar answers (responses) and the most similar e-commerce products respectively, in a fast and scalable way. The approach there was fairly generic: find dense representations for each item (questions, product images etc.) using unsupervised and supervised learning algorithms such as:

Unsupervised - PCA, LDA, Auto-encoders, TF-IDF weighted word vectors etc.

Supervised - Classification Neural Networks, Siamese Deep Neural Networks etc.

For neural network based approaches such as autoencoders, word vectors and Siamese networks, we usually refer to these representations as embeddings. Once computed, the embeddings are stored in some data structure such that when we query with a new embedding, we can retrieve the embeddings closest to the query in terms of a distance metric such as Euclidean (L2) distance, cosine distance, L1 distance etc.

The best part of this setup is that the same retrieval machinery works for representations learned from both unsupervised and supervised data.

In the rest of the post we will assume Euclidean distance as the metric.

The challenge is in retrieving the most similar embeddings for a query in a fast and scalable way. The brute force approach is straightforward: store the embeddings in a 2D array, compute the distance of the query embedding to each row of this array, and sort the distances from lowest to highest. If the 2D array is NxD dimensional then the run-time of this approach scales as O(ND) per query.
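For reference, a minimal brute-force scan might look like the following sketch (the names 'embeddings', 'query' and 'brute_force_knn' are just placeholders for illustration):

import numpy as np

def brute_force_knn(embeddings, query, k=5):
    # embeddings: N x D array, query: D-dimensional vector
    # Compute the euclidean distance from the query to every row: O(N*D)
    dists = np.linalg.norm(embeddings - query, axis=1)
    # Indices of the k smallest distances, nearest first
    nearest = np.argsort(dists)[:k]
    return nearest, dists[nearest]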

But in practical scenarios N is in the millions or billions, so for each query we need to scan all N rows. This run-time is not efficient enough for large N. One way to retrieve the most similar items in less than O(ND) time is to use a data structure known as the KD-Tree, which is what we discuss in detail in this post.

Let's first consider the 1-D case. Given N points in a 1-D space, if the points are sorted then checking whether a query point exists takes O(logN) time using plain binary search. Searching for the nearest point to a query is also O(logN), as shown below.

A Python implementation to find the nearest element to a query point (assuming the points are sorted in 1-D space):

def get_nearest(arr, ref):
    left, right = 0, len(arr) - 1

    # Binary search for the insertion position of 'ref'
    while left <= right:
        mid = (left + right) // 2

        if arr[mid] > ref:
            right = mid - 1
        else:
            left = mid + 1

    # 'left' now points to the smallest element greater than 'ref'
    # (or to len(arr) if 'ref' is greater than every element)
    if left == len(arr):
        return arr[-1]
    if left > 0 and abs(arr[left - 1] - ref) < abs(arr[left] - ref):
        return arr[left - 1]
    return arr[left]

The code binary-searches for the interval of the array in which our reference point lies, and then returns whichever endpoint of that interval is closer to the reference point.
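For example, with a small sorted array:

arr = [1, 4, 9, 16, 25]

print(get_nearest(arr, 10))   # 9 (closer to 10 than 16)
print(get_nearest(arr, 13))   # 16 (closer to 13 than 9)
print(get_nearest(arr, 30))   # 25 (query is beyond the last element)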

The 1-D case is quite simple. Now consider points in 2-D. In 1-D there are only two ways of placing a point w.r.t. another point, i.e. left or right. In 2-D there are 4 possibilities: left-top, left-bottom, right-top and right-bottom. Suppose we used 2 binary trees, one for the x-axis and one for the y-axis, and applied the 1-D approach above to find the nearest point along each axis separately. The nearest point along the two axes need not be the same point, and in fact neither of them need be the nearest point to the query in terms of Euclidean distance.

This is because the Euclidean distance is d = sqrt((x - x1)^2 + (y - y1)^2).

Even if (x-x1) is small, (y-y1) can be large and vice versa.

In yet another approach we could use a Quad-Tree instead of a binary tree, i.e. each node has up to 4 children corresponding to the 4 quadrants. Although we can traverse the quadrants hierarchically to reach the smallest quadrant containing our query point, there is no guarantee that the closest point to the query lies in the same quadrant. The worst case time complexity remains O(ND).

The KD-Tree is a more generic data structure for D > 1 dimensions. Although the worst case time complexity of nearest neighbour search in a KD-Tree is still O(ND), depending on how "good" the structure of the tree is, we can find the nearest point in O(D*logN) time in the best case.

Next we show one of the simplest ways to construct a KD-Tree on a 2-D space.

The idea is to first divide the points into two parts based on their x-coordinates (points which lie to the left of the split point and points which lie to its right). Then for each of the left and right parts, split the points further into two parts, but this time based on the y-coordinate, so each part is further split into up and down sub-parts. Repeat the process in round-robin fashion until no further splitting is possible.

KD-Tree splitting in 2 dimensions (Source: ResearchGate)

The split-point is chosen to be the point having the median value along the split-axis.
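As a rough sketch, a recursive round-robin construction over 2-D points might look like this (illustrative only; the implementation later in this post uses a different splitting rule and an iterative, queue-based build):

def build_round_robin(points, depth=0):
    # points: list of (x, y) tuples; the split axis alternates with depth
    if len(points) == 0:
        return None

    axis = depth % 2
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2

    return {
        'point': points[mid],   # median point along the split axis
        'axis': axis,
        'left': build_round_robin(points[:mid], depth + 1),
        'right': build_round_robin(points[mid + 1:], depth + 1),
    }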

Once the tree is constructed, searching for a point in the tree takes O(D*log(N)) time. Why?

If the root is the query point then we are done. Otherwise we compare the query with the root along the root's split axis, i.e. the x-axis. If the query's value along the x-axis is less than or equal to the split value at the root then we search the left sub-tree, else we search the right sub-tree.

For example, to search for (60, 80) in the tree below, one would compare 60 with 51 (the x-axis value of the root), decide that 60 > 51 and go to the right sub-tree. Then compare 80 with 70 (the y-axis value of the root of the right sub-tree), decide that 80 > 70 and again go right, and so on.

KD-Tree constructed out of 2D points.
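Exact search on the dictionary-based nodes from the sketch above could be written as follows (again just an illustration; ties on the split value are not handled carefully here):

def exact_search(node, query):
    # Returns True if 'query' is present in the tree built by build_round_robin
    if node is None:
        return False
    if node['point'] == tuple(query):
        return True

    axis = node['axis']
    if query[axis] <= node['point'][axis]:
        return exact_search(node['left'], query)
    return exact_search(node['right'], query)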

Note that by round-robin splitting we are trying to keep the tree as balanced as possible, so with N points the maximum height is O(logN). Since at each level of the tree we compare against at most one point, and each point is D (=2) dimensional, the time complexity for exact search is O(D*log(N)). But round-robin splitting may not always be optimal (what if all or most of the points lie along the x-axis?).

Searching the nearest point is a bit tricky.

Let's traverse the tree as we would for exact search. For example, if our query point is (50, 2) (shown in green below), then by traversing the above tree we end up at the bounding box with (10, 30) on the left boundary and (25, 40) on the top boundary. Is it true that either (10, 30) or (25, 40) is the closest point to (50, 2)? No, because the closest point is (55, 1), which lies inside the bounding box to the right. That means we cannot simply traverse the tree in a single left-or-right direction at each step, as we do for exact search.

Bounding Box

To solve this problem, we need to know whether we must search only one side of a sub-tree or both sides. But how do we know that?

Since we reached this region by traversing last through (10, 30), we compute the Euclidean distance of (50, 2) from (10, 30), which is 48.83. This distance becomes the radius of a sphere centered at the query. Now, if this sphere cuts through the plane of the split axis (x = 10), then we also need to scan the left side of the plane, because some region on the left side of the plane may fall inside the sphere. This is indeed the case here, since the perpendicular distance of the query point from the plane x = 10 is 40 < 48.83, implying that the sphere cuts through the x = 10 plane.
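The test itself is simple: the sphere around the query intersects the splitting plane exactly when the perpendicular distance from the query to that plane (the absolute difference along the split axis) is smaller than the current best distance. As a sketch (the helper name here is hypothetical):

def must_check_other_side(query, split_axis, split_val, best_dist):
    # The splitting hyperplane is {x : x[split_axis] == split_val}; the
    # perpendicular distance to it is the absolute difference along that axis
    return abs(query[split_axis] - split_val) < best_dist

# For the running example: abs(50 - 10) = 40 < 48.83, so the left side of
# the x = 10 plane must also be scanned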

Fortunately (1, 10) is at a distance of 49.65, which is greater than 48.83 and thus our best guess still remains (10, 30).

Once we have looked into both branches of a sub-tree, we start to backtrack up the tree. (25, 40) splits on the y-axis (y = 40). Since the perpendicular distance from (50, 2) to y = 40 is 38 < 48.83 (the radius of the smallest sphere seen so far), the sphere cuts through this plane too and we need to scan the other side of y = 40 as well. Repeating the same process for the right sub-tree of (25, 40), we find that the point (50, 50) is closer, i.e. distance = 48 < 48.83. Thus our new best guess becomes (50, 50).

Once we are done checking the left and right sub-trees of (25, 40), we check the distance of the query from (25, 40) itself. It is 45.49 < 48, so (25, 40) becomes our current best solution.

Next up, (51, 75) splits on the x-axis with the x = 51 plane. The perpendicular distance of (50, 2) from the x = 51 plane is 1, which is less than 45.49 (the radius of the smallest sphere seen so far), so we also need to scan the right sub-tree rooted at (51, 75). Repeating the above process, we eventually find the true closest point (55, 1). Note that by the time we discover the solution (55, 1), we have scanned every point in the tree. Thus nearest neighbour search in a KD-Tree has a worst case complexity of O(ND), the same as a linear scan.

If the query had instead been, say, (12, 33), then we would not have needed to go into the right sub-tree of (25, 40) or the right sub-tree of (51, 75), because the radius of the sphere from (10, 30) to (12, 33) is 3.6, which is less than the perpendicular distances of (12, 33) from y = 40 (which is 7) and from x = 51 (which is 39).

Here is a small animation taken from Wikipedia explaining how we go about searching the tree:

Nearest Neighbor Search

Regarding the algorithm to split a node at each level, we used a round-robin rule to select the split axis. Although this technique produces a balanced tree of height O(logN), it might not be the best choice for making the search efficient. To reduce the search time, we should try to prevent backtracking, i.e. visiting the other branch of a sub-tree, as much as possible.

If the points in the tree are "well spaced out" we can achieve a run-time of nearly O(D*logN). Some strategies for achieving better splits at each node:

  • Select the axis along which the difference between the minimum and maximum values is largest. Choose the split point to be the midpoint between the min and max, or the median.
  • Select the axis along which the variance of the values is largest. Choose the split point to be the median along this axis.
  • Select the axis along which the gap between the point just after the median and the point just before the median is largest. Choose the split point to be the median along this axis.

Time to code up !!!

We will use a queue to construct the KD-Tree in a "breadth first" manner instead of using recursion, because when the data size is large, i.e. in the millions, a recursive solution often hits Python's "maximum recursion depth exceeded" error.

We create 3 classes: Node, LeafNode and KDTree. The "Node" class represents the internal nodes whereas "LeafNode" represents the leaf nodes.

Let's say that we are going to use the min-max splitting algorithm mentioned above and use the median along the best axis as the split value. For this:

  1. Find the minimum and maximum along each axis over all embeddings.
  2. Find the axis 'A' with the maximum difference between the min and max from step 1.
  3. Choose the split value 'V' as the median of all embeddings along axis A.
  4. Put all embeddings whose value along axis A is less than or equal to V into the left sub-tree.
  5. Put all embeddings whose value along axis A is greater than V into the right sub-tree.
  6. For the left and right sub-trees, repeat from step 1.
  7. If all the embeddings in a sub-tree are equal, add them to a LeafNode.
  8. If the number of embeddings in a sub-tree is at most the specified 'leaf_size' parameter, add them to a LeafNode.

import numpy as np
import heapq
from collections import deque
from sklearn.metrics.pairwise import euclidean_distances
from six import string_types


def max_min_mid_split(vectors):
    maxs = np.max(vectors, axis=0)
    mins = np.min(vectors, axis=0)

    split_axis = np.argmax(maxs-mins)
    split_val = 0.5 * (maxs[split_axis] + mins[split_axis])
    
    return split_axis, split_val


def max_min_median_split(vectors):
    maxs = np.max(vectors, axis=0)
    mins = np.min(vectors, axis=0)

    split_axis = np.argmax(maxs-mins)
    split_val = np.median(vectors[:,split_axis])
    
    return split_axis, split_val


def max_variance_split(vectors):
    variances = np.var(vectors, axis=0)

    split_axis = np.argmax(variances)
    split_val = np.median(vectors[:,split_axis])
    
    return split_axis, split_val


# Map algorithm names to their splitting functions
split_functions = {
    'max_min_mid_split': max_min_mid_split,
    'max_min_median_split': max_min_median_split,
    'max_variance_split': max_variance_split,
}


def get_split(vectors, algorithm='max_min_median_split'):
    if isinstance(algorithm, string_types) and algorithm in split_functions:
        return split_functions[algorithm](vectors)
    return None, None


class Node(object):
    def __init__(self, split_axis=None, split_val=None):
        self.split_axis = split_axis
        self.split_val = split_val
        self.left, self.right = None, None
        
        
class LeafNode(object):
    def __init__(self, indices):
        self.indices = indices
        
        
class KDTree(object):
    def __init__(self, vectors, leafsize=10, algorithm='max_min_median_split'):
        self.leaf_size = leafsize
        self.tree = None
        self.vectors = vectors
        self.algorithm = algorithm
        
    def construct(self):
        root_indices = np.arange(self.vectors.shape[0])

        if self.vectors.shape[0] <= self.leaf_size:
            self.tree = LeafNode(root_indices)
        else:
            self.tree = Node()
            queue_obj = deque([(self.tree, root_indices, None, None)])

            while len(queue_obj) > 0:
                curr_obj, indices, parent_obj, direction = queue_obj.popleft()

                if isinstance(curr_obj, Node):
                    split_axis, split_val = get_split(self.vectors[indices,:], self.algorithm)
                    
                    if split_axis is None:
                        raise ValueError("Incorrect splitting algorithm specified")
                    
                    vec = self.vectors[indices, split_axis]

                    l_indices = indices[np.nonzero(vec <= split_val)[0]]
                    r_indices = indices[np.nonzero(vec > split_val)[0]]
                    
                    # All embeddings fall on one side of the split (e.g. they are
                    # all identical), so store them as a LeafNode instead
                    if len(r_indices) == 0 or len(l_indices) == 0:
                        if parent_obj is not None:
                            if direction == 0:
                                parent_obj.left = LeafNode(indices)
                            else:
                                parent_obj.right = LeafNode(indices)
                        else:
                            self.tree = LeafNode(indices)
                            break
                        
                    else:
                        curr_obj.split_axis = split_axis
                        curr_obj.split_val = split_val

                        if len(l_indices) <= self.leaf_size:
                            l_node_obj = LeafNode(l_indices)
                        else:
                            l_node_obj = Node()

                        if len(r_indices) <= self.leaf_size:
                            r_node_obj = LeafNode(r_indices)
                        else:
                            r_node_obj = Node()

                        curr_obj.left, curr_obj.right = l_node_obj, r_node_obj

                        queue_obj.append((l_node_obj, l_indices, curr_obj, 0))
                        queue_obj.append((r_node_obj, r_indices, curr_obj, 1))

Each element in the queue of the 'construct' method stores the current node object, the indices of the embeddings, the parent node object of the current object and the direction of the current object from the parent i.e. left (0) or right (1).

Storing the parent node object in the queue is important because, if all the embeddings at the current node are identical and cannot be split further, we attach a LeafNode to the parent in place of the current object instead of keeping it as an internal Node.

Run-time analysis of the 'construct' method:

The number of nodes in the tree is O(N/leaf_size): the number of leaf nodes is roughly N/leaf_size, and for a binary tree with M leaves the total number of nodes is at most 2M (M + M/2 + M/4 + ... + 1 <= 2M). Thus the height of the (balanced) tree is bounded above by O(log(N/leaf_size)).

The total time spent splitting the nodes at any one level is O(ND), because if there are M nodes at a level then each node takes O(D*N/M) time to split. Since the height of the tree is O(log(N/leaf_size)), the overall construction time is O(N*D*log(N/leaf_size)).

Construct the tree by instantiating the class and calling its 'construct' method:

tree = KDTree(vectors=vectors, leafsize=100)
tree.construct()
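For example, one can get a rough feel for the construction time on random data (an illustrative sketch; the sizes are arbitrary and timings vary by machine):

import time
import numpy as np

vectors = np.random.random((100000, 64))

start = time.time()
tree = KDTree(vectors=vectors, leafsize=1000)
tree.construct()
print("construction took %.2f seconds" % (time.time() - start))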

Next we move on to our retrieval methods. We want to retrieve the most similar items in two possible ways.

  1. Retrieve the K most similar items for a query item.
  2. Retrieve the most similar items that are within a radius of R from the query point.

For the 1st part, we maintain a stack of the nodes visited along a path from the root towards a LeafNode. At each Node, we compare the Node's split value with the query's value along the split axis. If the query's value along the split axis is less than or equal to the Node's split value, we go towards the left sub-tree, else we go towards the right sub-tree. We repeat this until we reach a LeafNode.

When we reach a LeafNode, we compute the Euclidean distances between the query and all the points in the LeafNode, and push the compared points into a K-sized max-heap priority queue keyed by distance (since Python's heapq is a min-heap, we push negated distances). This bounded priority queue ensures that at any time it stores the (at most) K nearest points seen so far.

But as we saw earlier, all of the closest points may not lie in the same LeafNode, so we may also need to traverse towards the sibling sub-tree, depending on whether the sphere centered at the query intersects the splitting plane.

Thus, for a given Node, if we have already visited the left sub-tree as dictated by the split axis and split value, we decide whether to also visit the right sub-tree based on the maximum distance among the K nearest points found so far (the root of the max-heap), or visit it unconditionally if we have found fewer than K points so far. Here is the method for this approach (a method of the KDTree class):

def query_count(self, query_vector, k=5):
    max_heap, visited = [], set()
    node_stack = [self.tree]

    while len(node_stack) > 0:
        curr_obj = node_stack[-1]

        if isinstance(curr_obj, LeafNode):
            # Linear scan over the points in this leaf; keep the k nearest seen
            # so far in a max-heap (heapq is a min-heap, so distances are negated)
            distances = euclidean_distances([query_vector], self.vectors[curr_obj.indices,:])[0]
            for dist, idx in zip(distances, curr_obj.indices):
                if len(max_heap) < k:
                    heapq.heappush(max_heap, (-dist, idx))
                else:
                    if dist < -max_heap[0][0]:
                        heapq.heappop(max_heap)
                        heapq.heappush(max_heap, (-dist, idx))

            visited.add(curr_obj)
            node_stack.pop()

        else:
            split_axis, split_val = curr_obj.split_axis, curr_obj.split_val

            if query_vector[split_axis] <= split_val:
                if curr_obj.left not in visited:
                    node_stack.append(curr_obj.left)
                else:
                    max_dist = -max_heap[0][0]
                    if (max_dist > abs(query_vector[split_axis] - split_val) or len(max_heap) < k) and curr_obj.right not in visited:
                        node_stack.append(curr_obj.right)
                    else:
                        visited.add(curr_obj)
                        node_stack.pop()

            else:
                if curr_obj.right not in visited:
                    node_stack.append(curr_obj.right)
                else:
                    max_dist = -max_heap[0][0]
                    if (max_dist > abs(query_vector[split_axis] - split_val) or len(max_heap) < k) and curr_obj.left not in visited:
                        node_stack.append(curr_obj.left)
                    else:
                        visited.add(curr_obj)
                        node_stack.pop()

    # Return the k nearest points as (distance, index) pairs, nearest first
    return sorted([(-neg_dist, idx) for neg_dist, idx in max_heap])

For the 2nd approach, retrieving the items within a given radius, things are more straightforward. Unlike the 1st approach, we collect the matching points in a list instead of a max-heap priority queue. And for a given Node, if we have already visited the left sub-tree as dictated by the split axis and split value, we decide whether to also visit the right sub-tree based on whether the radius reaches across the splitting plane.

Here is the code for the 2nd part (again a method of the KDTree class):

def query_radius(self, query_vector, radius=0.1):
    output, visited = [], set()
    node_stack = [self.tree]

    while len(node_stack) > 0:
        curr_obj = node_stack[-1]

        if isinstance(curr_obj, LeafNode):
            distances = euclidean_distances([query_vector], self.vectors[curr_obj.indices,:])[0]
            for dist, idx in zip(distances, curr_obj.indices):
                if dist <= radius:
                    output.append((dist, idx))

            visited.add(curr_obj)
            node_stack.pop()

        else:
            split_axis, split_val = curr_obj.split_axis, curr_obj.split_val

            if query_vector[split_axis] <= split_val:
                if curr_obj.left not in visited:
                    node_stack.append(curr_obj.left)
                else:
                    if radius >= abs(query_vector[split_axis]-split_val) and curr_obj.right not in visited:
                        node_stack.append(curr_obj.right)
                    else:
                        visited.add(curr_obj)
                        node_stack.pop()

            else:
                if curr_obj.right not in visited:
                    node_stack.append(curr_obj.right)
                else:
                    if radius >= abs(query_vector[split_axis]-split_val) and curr_obj.left not in visited:
                        node_stack.append(curr_obj.left)
                    else:
                        visited.add(curr_obj)
                        node_stack.pop()
    return output
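A quick way to sanity-check both methods is to compare them against a brute-force scan on a small random dataset (a sketch with arbitrary sizes, assuming query_count and query_radius are added as methods of the KDTree class and that query_count returns (distance, index) pairs as above):

vectors = np.random.random((10000, 16))

tree = KDTree(vectors=vectors, leafsize=100)
tree.construct()

query = np.random.random(16)

# 5 nearest neighbours via the KD-Tree vs. a full linear scan
knn = tree.query_count(query, k=5)
brute = np.argsort(np.linalg.norm(vectors - query, axis=1))[:5]
print([idx for dist, idx in knn])
print(list(brute))

# All points within a distance of 1.5 from the query
print(tree.query_radius(query, radius=1.5))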

The 'leaf_size' parameter controls the trade-off between the size of the tree (and hence the run-time of 'construct') and the speed of searching for the nearest points. If 'leaf_size' is small then the time taken to construct the KD-Tree is high, but searching should "ideally" be fast, since the number of points scanned linearly within each leaf (which is proportional to 'leaf_size') is reduced.

Similarly if the 'leaf_size' parameter is large then the time taken to construct the KD-Tree will be less but search speed might slow down.

With this Python implementation, however, that is not always the case, which is why I said "ideally" in the earlier paragraph.

Let us say that in the overall search process for a single query we scan 100K points. If 'leaf_size' is 100 then we visit 1000 such LeafNodes, so the 'euclidean_distances' method is called 1000 times on batches of 100 each. Whereas if 'leaf_size' is 1000, then 'euclidean_distances' is called 100 times on batches of 1000 each.

The time taken in these two scenarios is not the same, because the time taken by 'euclidean_distances' on a batch of 1000 is not 10 times the time taken on a batch of 100; it is closer to only 2 to 3 times as much, since 'euclidean_distances' uses fast vectorized matrix libraries to speed up the computation. Thus:

(time per batch of 1000) x 100 calls  <  (time per batch of 100) x 1000 calls
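One can verify this batching effect directly by timing 'euclidean_distances' on different batch sizes (a rough sketch; absolute numbers depend on the machine and the underlying BLAS library):

import time
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

data = np.random.random((100000, 64))
query = np.random.random((1, 64))

for batch_size in (100, 1000):
    batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

    start = time.time()
    for batch in batches:
        euclidean_distances(query, batch)
    print(batch_size, time.time() - start)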

Here is a little chart of the time taken by the 'euclidean_distances' method for different batch ('leaf_size') values. Note that the relationship is not exactly linear.

Time taken by the 'euclidean_distances' method against the 'leaf_size' parameter

Thus with the above implementation it is better to go with a larger 'leaf_size' value (somewhere between 1000 and 5000 for a million data points).

Refer to my earlier post on "Designing large scale similarity systems using deep learning" for drawbacks of the KD-Tree approach. In the next post we will look at Product Quantization (PQ) for fast nearest neighbour search. While KD-Tree NN search is exact, search with PQ is approximate, trading some accuracy of results for speed and memory.


Categories: MACHINE LEARNING, PROBLEM SOLVING
