Machine Learning, AI and Programming

Fast Nearest Neighbour Search - Product Quantization

In this ongoing series on fast nearest neighbor search algorithms, we look at the Product Quantization technique. In the last post we looked at KD-Trees, which are efficient data structures for low-dimensional embeddings. They also work in higher dimensions, provided the nearest neighbor search radius is small enough to avoid backtracking. Product Quantization (PQ) does not build any tree index. Instead, it relies on Approximate Nearest Neighbor (ANN) search to reduce memory requirements and search time at inference.

ANN refers to approaches that do not compare a query directly with the items in the database. Instead, they compare it with some approximate representation, or hash, of the items. For example, if we cluster the items and then compare the query with the cluster centroids, that is an ANN method: each centroid acts as a kind of hash for all the items in its cluster. In a KD-Tree, by contrast, we compute distances directly between the query embedding and the items in the leaf nodes, so it is not an ANN method.

[Figure: K-Means clustering]

The advantage of ANN is that we do not need to store the items themselves (or their embeddings) for inference. We store only their hashes. For example, if the average cluster size is around 10, memory drops by a factor of about 10, because every 10 items share one centroid (hash). We still need a map from each item index to its hash, because a query must return actual item indices. So a query is still compared against each item, but through the item's hash rather than its original embedding.
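Below is a minimal sketch of the cluster-based ANN idea described above. It uses SciPy's `kmeans2` and `vq` for clustering and random data in place of real embeddings. At query time it uses only `centroids` and the `item_to_centroid` map. `ann_query` is a hypothetical helper name, not from the original post.

```python
import numpy as np
from scipy.cluster.vq import kmeans2, vq

rng = np.random.default_rng(0)
items = rng.normal(size=(1000, 32))  # stand-in for real item embeddings

# Build: cluster the items, then keep only the centroids and the item -> centroid map
num_clusters = 100  # average cluster size of about 10
centroids, _ = kmeans2(items, num_clusters, minit='points', seed=0)
item_to_centroid, _ = vq(items, centroids)  # the "hash" of every item

def ann_query(query, k=5):
    """Rank items by the distance from the query to their centroid, not to the item itself."""
    centroid_dist = np.linalg.norm(centroids - query, axis=1)
    item_dist = centroid_dist[item_to_centroid]
    return np.argsort(item_dist, kind="stable")[:k]

print(ann_query(items[0]))  # item 0 is ranked first; ties within its cluster follow by index
```

Only 100 x 32 centroid values plus 1000 integers are needed at query time, instead of 1000 x 32 embedding values. The price is that all items in a cluster receive the same distance.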

The disadvantage of ANN is that computing a centroid or hash loses information, so the results are less accurate than with direct query-to-item comparison. Imagine a very bad hash function: the first 16 dimensions of a 1024-dimensional embedding. It reduces memory 64 times, but many items that are not similar to one another can agree on their first 16 dimensions.

One way to overcome this is to use multiple independent hash functions, which lowers the error rate. Suppose there are M independent hash functions, each with an error rate of E. We say that a query matches an item if at least one hash function says it is a match. The prediction is wrong only if all M hash functions are wrong, which happens with probability $E^M$. The probability of a correct prediction is therefore:

$$P_{\text{correct}} = 1 - E^M$$

For example, with E = 0.4 a single hash function is correct with probability 1 - E = 0.6. With M = 3 hash functions, the probability of a correct result rises to $1 - 0.4^3 = 0.936$.

So instead of one hash built from the first 16 dimensions only, consider 64 hash functions, one for each 16-dimensional sub-vector (64 x 16 = 1024).

Now suppose the hash of a 16-dimensional sub-vector has an error rate as high as 0.8. The combined error rate with 64 such hash functions is then only $0.8^{64} \approx 6.3 \times 10^{-7}$. This assumes that the sub-vectors of a vector are independent of each other, which is not always the case.
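The arithmetic above can be checked in a few lines of Python. The block also runs a small Monte Carlo simulation that assumes the hash functions err independently, which is the same assumption the formula makes. `prob_at_least_one_correct` is an illustrative helper, not part of the original post.

```python
import numpy as np

def prob_at_least_one_correct(error_rate, num_hashes):
    """P(at least one of M independent hashes is correct) = 1 - E**M."""
    return 1.0 - error_rate ** num_hashes

print(prob_at_least_one_correct(0.4, 3))  # about 0.936
print(0.8 ** 64)                          # combined error rate, about 6.3e-07

# Monte Carlo check: M independent hashes, each wrong with probability E
rng = np.random.default_rng(0)
E, M, trials = 0.4, 3, 200_000
all_wrong = (rng.random((trials, M)) < E).all(axis=1)
simulated = 1.0 - all_wrong.mean()
print(simulated)                          # close to 0.936
```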

[Figure: Partitioning the embedding into equal parts and clustering each part independently]

The idea of PQ is pretty straightforward.

  • Divide the D-dimensional embeddings into M partitions of size D/M.
  • Cluster each of the M parts separately, which gives M different sets of clusters. For example, if D=1024 and M=64, then:
    • The 1st set of clusters is built from dimensions 0 to 15 of all the items.
    • The 2nd set of clusters is built from dimensions 16 to 31 of all the items, and so on.
  • The number of clusters in each set is K.
    • The number of false positives decreases as K increases, because fewer items map to each centroid.
    • The memory requirement and the time taken to construct the clusters both increase with K.
  • Each of the M sets of clusters has K centroids, numbered 0 to K-1.
  • Each item is now represented as an array of M integers, each between 0 and K-1.
    • For example, if the 3rd integer is 25, the item's centroid id in the 3rd partition is 25.
  • Instead of storing the 1024-dimensional embedding of each item, we store:
    • An array of M integers for each item. Let's call this array U.
    • A 3D array of dimensions (M, K, D/M) holding the centroids for each partition and each cluster id, i.e. K*D floats in total. Let's call this array V.
    • With 1 billion items, storing the 1024-dimensional embeddings as 32-bit floats would take about 4.1 TB (3.7 TiB).
    • With 8-bit integers (one byte per code), the array U needs only 64 GB (about 60 GiB).
    • An 8-bit code can index at most 2^8 = 256 centroids, so K is at most 256. With K=256, array V needs only 256*1024*4 bytes = 1 MiB. A much larger codebook such as K=2^20 would still need only 4 GiB for V, but each code would then need at least 20 bits, which would enlarge U.
    • Thus with the PQ strategy, all the data needed for inference over 1 billion items (U+V) fits in about 64 GB of memory.
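The storage figures above can be reproduced with a short calculation. It assumes 4-byte floats for embeddings and centroids and 1-byte codes in U. The variable names are illustrative.

```python
N = 1_000_000_000   # number of items
D = 1024            # embedding size
M = 64              # partitions, each of size D/M = 16
FLOAT_BYTES, CODE_BYTES = 4, 1   # assumed: float32 values, uint8 codes

raw_bytes = N * D * FLOAT_BYTES  # full embeddings
u_bytes = N * M * CODE_BYTES     # array U: M codes per item

def v_bytes(K):
    """Array V: (M, K, D/M) float32 centroids, i.e. K * D floats."""
    return M * K * (D // M) * FLOAT_BYTES

GiB, TiB = 2**30, 2**40
print(raw_bytes / TiB)       # about 3.7 TiB
print(u_bytes / GiB)         # about 59.6 GiB (64 GB)
print(v_bytes(256) / 2**20)  # 1.0 MiB for K = 256
print(v_bytes(2**20) / GiB)  # 4.0 GiB for K = 2**20
```

U dominates the total. The size of V depends only on K and D, not on the number of items.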

Any D-dimensional embedding can be represented by M integers, each from 0 to K-1, because each of its M partitions is closest to one of the K centroids of that partition. So M integers, each from 0 to K-1, can encode $K^M$ distinct items.

The memory efficiency can be expressed as follows. This assumes 32-bit floats for the raw embeddings, 8-bit codes, and all $K^M$ codes in use. The numerator is the raw storage and the denominator is the PQ storage (U + V):

$$\frac{4 \cdot D \cdot K^M}{M \cdot K^M + 4 \cdot K \cdot D}$$

For a query embedding, we need to follow the same process as above to retrieve the most similar items.

  • Divide the query embedding into M equal partitions of size D/M each.
  • For each partition 'm', find the square of the euclidean distances to each of the K centroids in the m-th partition of the matrix V above.
    • We will obtain a 2D matrix W of dimensions (M, K) for each query.
  • To find the distances to each item X:
    • From the matrix U, obtain all the centroid ids corresponding to the item X.
    • Sum up the squared euclidean distances corresponding to the partitions and the centroid ids from the matrix W.
    • Return the square root of the sum of distances above.
    • For example, if the row corresponding to item X in matrix U is:

$$[K_X(0), K_X(1), \ldots, K_X(M-1)]$$

then the corresponding squared Euclidean distances in W are:

$$W[0][K_X(0)],\; W[1][K_X(1)],\; \ldots,\; W[M-1][K_X(M-1)]$$

where W[i][j] is the value at index (i, j) of matrix W. Thus the distance of the query from item X is:

$$d = \sqrt{W[0][K_X(0)] + W[1][K_X(1)] + \cdots + W[M-1][K_X(M-1)]}$$

The distance d above is exactly the Euclidean distance between the query and the centroid representation of item X. Let x be the query and y the centroid representation of item X, i.e. the concatenation of its M centroids. The squared Euclidean distance between x and y is:

$$(x_0 - y_0)^2 + (x_1 - y_1)^2 + \cdots + (x_{D-1} - y_{D-1})^2$$

This can be regrouped into M sums of D/M terms each:

$$\sum_{j=0}^{D/M-1} (x_j - y_j)^2 + \sum_{j=D/M}^{2D/M-1} (x_j - y_j)^2 + \cdots + \sum_{j=(M-1)D/M}^{D-1} (x_j - y_j)^2$$

The m-th of these sums is exactly the table entry for partition m:

$$W[m][K_X(m)] = \sum_{j=mD/M}^{(m+1)D/M-1} (x_j - y_j)^2$$

Summing these M entries and taking the square root therefore gives d.
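The sketch below is a sanity check of the derivation. It compares the table-lookup distance with the distance to an explicitly reconstructed item. Random centroids and a random code row stand in for a trained V and U, so it checks only the identity, not clustering quality.

```python
import numpy as np

rng = np.random.default_rng(0)
D, M, K = 32, 4, 8
d = D // M
codewords = rng.normal(size=(M, K, d))  # stand-in for a trained array V
code = rng.integers(0, K, size=M)       # one item's row of U: [K_X(0), ..., K_X(M-1)]
query = rng.normal(size=D)

# W[m, k] = squared distance between query partition m and centroid k of partition m
W = ((codewords - query.reshape(M, 1, d)) ** 2).sum(axis=2)
d_table = np.sqrt(W[np.arange(M), code].sum())

# Rebuild the item from its M centroids and measure the distance directly
reconstruction = codewords[np.arange(M), code].reshape(D)
d_direct = np.linalg.norm(query - reconstruction)

print(np.isclose(d_table, d_direct))  # True
```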

[Figure: Comparing the query with the M by K distance matrix]

A Python implementation of PQ is as follows:

```python
import numpy as np
from scipy.cluster.vq import vq, kmeans2
from scipy.spatial.distance import cdist
from sklearn.cluster import MiniBatchKMeans

def get_kmeans_clusters(vectors, num_clusters, use_mini_batch=True):
    """Cluster `vectors` into `num_clusters` groups and return (centroids, labels)."""
    if use_mini_batch:
        batch_size = int(min(num_clusters/3.0+1, vectors.shape[0]))
        kmeans = MiniBatchKMeans(n_clusters=num_clusters, batch_size=batch_size, init='k-means++')
        kmeans.fit(vectors)
        return kmeans.cluster_centers_, kmeans.labels_
    else:
        centroids, labels = kmeans2(vectors, num_clusters, minit='points')
        return centroids, labels

class PQ(object):
    def __init__(self, num_partitions, num_codewords_per_partition):
        # Codes are stored as uint8, so each partition can have at most 256 codewords
        assert num_codewords_per_partition <= 256
        self.n, self.m = 0, 0
        self.num_partitions = num_partitions
        self.num_codewords_per_partition = num_codewords_per_partition
        self.pqcode = None     # array U: (n, M) centroid ids
        self.codewords = None  # array V: (M, K, D/M) centroids

    def construct(self, vectors):
        self.n, self.m = vectors.shape
        partition_dim = self.m // self.num_partitions
        self.codewords = np.empty((self.num_partitions, self.num_codewords_per_partition, partition_dim), np.float32)
        self.pqcode = np.empty((self.n, self.num_partitions), np.uint8)
        for m in range(self.num_partitions):
            sub_vectors = vectors[:, m * partition_dim : (m + 1) * partition_dim]
            if sub_vectors.shape[0] == 1:
                # A single item is its own centroid
                self.codewords[m] = np.mean(sub_vectors, axis=0)
            else:
                self.codewords[m], _ = get_kmeans_clusters(sub_vectors, self.num_codewords_per_partition, use_mini_batch=False)
            # Assign every sub-vector to its nearest centroid in this partition
            self.pqcode[:, m], _ = vq(sub_vectors, self.codewords[m])

    def _distances(self, query):
        # Approximate Euclidean distance from the query to every item
        partition_dim = self.m // self.num_partitions
        # dist_table is matrix W: squared distance from each query partition to each centroid
        dist_table = np.empty((self.num_partitions, self.num_codewords_per_partition), np.float32)
        for m in range(self.num_partitions):
            query_sub = query[m * partition_dim : (m + 1) * partition_dim]
            dist_table[m, :] = cdist([query_sub], self.codewords[m], 'sqeuclidean')[0]
        # For each item, look up its M centroid ids in W and sum over partitions
        return np.sqrt(np.sum(dist_table[np.arange(self.num_partitions), self.pqcode], axis=1))

    def query_count(self, query, k=5):
        # Return the k nearest items as (distance, item index) pairs
        dist = self._distances(query)
        nearest = np.argsort(dist, kind='stable')[:k]
        return [(dist[i], i) for i in nearest]

    def query_radius(self, query, radius=0.1):
        # Return all items within `radius` as (distance, item index) pairs
        dist = self._distances(query)
        return [(d, i) for i, d in enumerate(dist) if d <= radius]
```

Observe that the class object does not keep the training embedding vectors, because they are not needed at inference time. The class follows the same abstract interface as the KD-Tree class in our last post, so the two can be used interchangeably.

In the query methods, the loop over partitions can be parallelized because each set of clusters is independent.
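Threads are not the only option here. Another approach, which is our assumption and not the one used in the post, is to let NumPy broadcasting compute all M rows of the distance table at once. The sketch below compares it with the per-partition loop used in the class. `dist_table_loop` and `dist_table_vectorized` are hypothetical names.

```python
import numpy as np
from scipy.spatial.distance import cdist

def dist_table_loop(query, codewords):
    """Matrix W built one partition at a time, as in the PQ class."""
    M, K, d = codewords.shape
    table = np.empty((M, K))
    for m in range(M):
        table[m] = cdist([query[m * d:(m + 1) * d]], codewords[m], 'sqeuclidean')[0]
    return table

def dist_table_vectorized(query, codewords):
    """Matrix W for all M partitions at once via broadcasting."""
    M, K, d = codewords.shape
    diff = codewords - query.reshape(M, 1, d)    # (M, K, d)
    return np.einsum('mkd,mkd->mk', diff, diff)  # squared norms -> (M, K)

rng = np.random.default_rng(0)
codewords = rng.normal(size=(8, 16, 4))  # stand-in for V with M=8, K=16, D/M=4
query = rng.normal(size=32)
print(np.allclose(dist_table_loop(query, codewords), dist_table_vectorized(query, codewords)))  # True
```

The vectorized version materializes an (M, K, D/M) temporary array, i.e. K*D floats, the same size as V. This is cheap for small K but worth keeping in mind for very large codebooks.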

The drawbacks of PQ compared with the KD-Tree approach are:

  • The best-case run-time complexity of a KD-Tree query is O(D*logN), whereas a PQ query always costs O(K*D + N*M). That is O(K*D) to build the distance table W plus O(N*M) table lookups, i.e. linear in N.
    • When the search radius is very small, a KD-Tree hits its optimal run-time quite often.
    • In PQ we have to do a linear scan over all items for each query. When N = 1 billion, this is not a scalable approach.
    • Parallelizing the code, as noted earlier, reduces the search time in PQ but does not remove the linear dependence on N.
  • A KD-Tree returns the same results as a linear scan, whereas PQ does not reach that accuracy for arbitrary values of M and K.
    • PQ is an approximate algorithm, so we need to experiment with M and K to obtain good performance.
    • Higher values of M and K usually improve accuracy, but only up to a saturation point. Beyond it, increasing M or K gives no noticeable gain and only pushes the cost of PQ towards that of brute-force search.
  • KD-Trees work well with lower-dimensional embeddings, but PQ needs a higher value of M to work well, which requires a larger embedding size.
    • For a fixed partition size of, e.g., 8, M would be 16 for 128-dimensional embeddings but 128 for 1024-dimensional embeddings.
  • Euclidean distances in PQ, between two items or between a query and an item, can be very different from the true Euclidean distances.
    • In PQ, the distance between two items is the distance between their centroid representations.
    • A Euclidean distance threshold estimated in the original embedding space cannot be reused as-is for PQ search.
    • For example, if the L2 distance threshold used with linear search or a KD-Tree was Z, the same Z will not behave the same way in PQ.
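The sketch below illustrates the last drawback using deliberately tiny codebooks. We chose M=2 and K=2 so that collisions are certain. With only K^M = 4 possible codes, some pair among the first five items must share all their codes. The PQ distance for that pair is then exactly zero, even though their true distance is not. Random data stands in for real embeddings, and SciPy's `kmeans2`/`vq` stand in for the `PQ` class.

```python
import numpy as np
from scipy.cluster.vq import kmeans2, vq

rng = np.random.default_rng(0)
items = rng.normal(size=(200, 8))
M, K = 2, 2  # deliberately tiny codebooks (assumption for illustration)
d = items.shape[1] // M

codes = np.empty((len(items), M), dtype=np.int64)
recon = np.empty_like(items)  # centroid representation of every item
for m in range(M):
    sub = items[:, m * d:(m + 1) * d]
    centroids, _ = kmeans2(sub, K, minit='points', seed=0)
    codes[:, m], _ = vq(sub, centroids)
    recon[:, m * d:(m + 1) * d] = centroids[codes[:, m]]

# First pair of distinct items whose codes are identical in every partition
i, j = next((a, b) for a in range(len(items)) for b in range(a + 1, len(items))
            if np.array_equal(codes[a], codes[b]))
true_dist = np.linalg.norm(items[i] - items[j])
pq_dist = np.linalg.norm(recon[i] - recon[j])
print(true_dist, pq_dist)  # pq_dist is 0.0 although true_dist > 0
```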

The advantages of PQ over KD-Tree are quite clear:

  • KD-Trees suffer from the curse of dimensionality while PQ does not.
    • With higher dimensions and a large search radius, KD-Tree search is often slower than a linear scan.
  • The memory requirements of PQ are much smaller than those of a KD-Tree.
    • Storing 1 billion items in a KD-Tree would require around 4 TB of memory.
    • Storing 1 billion items with PQ would require only about 64 GB of memory.
  • KD-Tree cannot be distributed but PQ can be distributed across multiple machines.
    • We can hold centroids in each partition on different machines if data becomes too large to fit in a single machine.
    • We can shard on either M (partitions) or K (clusters).
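Below is a minimal sketch of sharding on M, with two hypothetical machines simulated in one process. Each shard holds some partitions of V and the matching columns of U, and returns partial sums of squared distances. A coordinator adds these partial sums before taking the square root. Random centroids and codes stand in for a trained index, and `shard_partial_sums` is an illustrative name.

```python
import numpy as np

rng = np.random.default_rng(1)
N, D, M, K = 100, 32, 8, 16
d = D // M
codewords = rng.normal(size=(M, K, d))   # stand-in for trained centroids (array V)
codes = rng.integers(0, K, size=(N, M))  # stand-in for array U
query = rng.normal(size=D)

def shard_partial_sums(partitions):
    """Work done on one machine that holds only the given partitions of V and columns of U."""
    partial = np.zeros(N)
    for m in partitions:
        table_m = ((codewords[m] - query[m * d:(m + 1) * d]) ** 2).sum(axis=1)  # row m of W
        partial += table_m[codes[:, m]]
    return partial

# Hypothetical layout: machine 0 holds partitions 0-3, machine 1 holds partitions 4-7
shards = [range(0, 4), range(4, 8)]
total = sum(shard_partial_sums(p) for p in shards)  # coordinator adds partial sums
approx_dist = np.sqrt(total)
```

This works because the squared distance is a sum over partitions, so the partitions can be added in any grouping. Sharding on K would instead split each row of W across machines.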

To work around the linear run-time of PQ, we can use tricks such as pre-grouping the items before the PQ step. For example, in our e-commerce domain we pre-group items by product type and run PQ independently on each product type.

While searching, we only search the PQ index of the query's product type.

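```python
import collections
from multiprocessing.pool import ThreadPool

import numpy as np

# NUM_PQ_PARTITIONS and NUM_CODEWORDS_PER_PARTITION are module-level settings;
# the values here are illustrative placeholders (K must be <= 256 for uint8 codes)
NUM_PQ_PARTITIONS = 64
NUM_CODEWORDS_PER_PARTITION = 256

def construct_product_quantizer(vectors, num_partitions, num_codewords):
    # Standardize this product type's vectors by their mean and standard deviation
    mean = np.mean(vectors) if vectors.shape[0] > 1 else np.zeros(vectors.shape[1])
    sd = np.std(vectors) if vectors.shape[0] > 1 else np.ones(vectors.shape[1])
    vectors = (vectors - mean)/sd
    # With too few items to form K clusters, fall back to a single codeword
    k = num_codewords if vectors.shape[0] > num_codewords else 1
    pq = PQ(num_partitions=num_partitions, num_codewords_per_partition=k)
    pq.construct(vectors)
    return pq, mean, sd

def construct_product_quantizer_per_PT(vectors, product_types):
    # Group item indices by product type
    pt_indices = collections.defaultdict(list)
    for i in range(len(product_types)):
        pt_indices[product_types[i]].append(i)
    # Build one quantizer per product type, 5 at a time
    with ThreadPool(5) as pool:
        out = pool.map(lambda x: (x[0], construct_product_quantizer(vectors[x[1],:], NUM_PQ_PARTITIONS, NUM_CODEWORDS_PER_PARTITION)), pt_indices.items())
    pq = {x[0]:x[1] for x in out}
    return pq
```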

Instead of calling 'construct_product_quantizer' directly, we generally call the latter function, 'construct_product_quantizer_per_PT'.

Observe that in 'construct_product_quantizer' we standardize the vectors of each product type by that product type's mean and standard deviation.

This is done because the mean and standard deviation differ across product types. In some product types the items can be very far from one another, while in others they can be very close. If we used a single distance threshold to match queries with items, we would never find a match for certain product types.

During search, we likewise need to standardize the query vector with the mean and standard deviation of its product type.
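Below is a self-contained sketch of the per-product-type flow. It uses SciPy's `kmeans2`/`vq` in place of the `PQ` class above, with two synthetic product types on very different scales. `build_pt_index` and `search_pt` are hypothetical helpers. The key step is that the query is standardized with its own product type's mean and standard deviation before the distance table is built.

```python
import numpy as np
from scipy.cluster.vq import kmeans2, vq

def build_pt_index(vectors, num_partitions, num_codewords, seed=0):
    """Standardize one product type's vectors by their mean/sd, then train a small PQ."""
    mean, sd = vectors.mean(), vectors.std()
    z = (vectors - mean) / sd
    d = z.shape[1] // num_partitions
    codewords, codes = [], []
    for m in range(num_partitions):
        sub = z[:, m * d:(m + 1) * d]
        centroids, _ = kmeans2(sub, num_codewords, minit='points', seed=seed)
        labels, _ = vq(sub, centroids)
        codewords.append(centroids)
        codes.append(labels)
    return {"mean": mean, "sd": sd,
            "codewords": np.stack(codewords),  # array V: (M, K, D/M)
            "codes": np.stack(codes, axis=1)}  # array U: (N, M)

def search_pt(index, query, k=3):
    """Standardize the query with its product type's mean/sd, then rank items by PQ distance."""
    q = (query - index["mean"]) / index["sd"]
    M, K, d = index["codewords"].shape
    table = ((index["codewords"] - q.reshape(M, 1, d)) ** 2).sum(axis=2)  # matrix W
    dist = np.sqrt(table[np.arange(M), index["codes"]].sum(axis=1))
    order = np.argsort(dist, kind="stable")[:k]
    return [(float(dist[i]), int(i)) for i in order]

# Usage: two hypothetical product types whose embeddings live on very different scales
rng = np.random.default_rng(0)
data = {"shoes": rng.normal(0.0, 1.0, size=(300, 16)),
        "laptops": rng.normal(5.0, 20.0, size=(300, 16))}
indexes = {pt: build_pt_index(v, num_partitions=4, num_codewords=8) for pt, v in data.items()}
print(search_pt(indexes["laptops"], data["laptops"][7], k=3))
```

The distances returned for both product types are on the standardized scale. This is what allows a single threshold to be applied across product types.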
