Machine Learning, AI and Programming

Designing large scale similarity models using deep learning

Finding similar texts or images is a very common problem in machine learning, used extensively for search and recommendation. Although the problem is common and has high business value for some organisations, it remains one of the most challenging when the database is very large (say, more than 50 GB) and we do not want to lose much precision and recall by retrieving only 'approximately' close results.

In this post we are going to look at how to find similar items using deep learning techniques along with standard engineering best practices. Most importantly, we are going to tackle this for a dataset that is larger than the available CPU memory (RAM) of an 8 GB or even a 16 GB machine.

A generic approach to solve this kind of problem is to:

  1. For text data.
    • Compute representation for each item using either PCA, LDA or TF-IDF weighted word vectors etc.
    • Either do a hierarchical tree clustering or create a fast nearest neighbour lookup index such as a KD-Tree with the representations.
    • For a new incoming item, compute its representation and then search the clusters or the KD-Tree for the most similar items within certain distance threshold.
    • The run-time complexity for lookup in both approaches ideally would be O(D*logN) where N is the number of items and D is the dimension of the representations.
  2. For image data.
    • Compute representation for each image using histogram or hashing based technique.
    • 'pHash' would be a desirable choice or in combination with multiple hashing algorithms such as 'aHash' and 'dHash'.
    • More modern solutions involve encoding the image into embeddings using Autoencoders. Denoising and Variational Autoencoders do a good job.
    • The remaining steps for index creation and lookup would be same as for text data.
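
The index-and-lookup steps above can be sketched with a toy, pure-Python KD-tree. This is only an illustration under stated assumptions: the function names `build_kdtree` and `radius_search` and the sample vectors are hypothetical, not taken from the original code, and in practice a library implementation such as scikit-learn's `KDTree` would normally be used instead.

```python
import math

def build_kdtree(points, indexes=None, depth=0):
    """Recursively build a KD-tree; each node stores an item index, not the vector."""
    if indexes is None:
        indexes = list(range(len(points)))
    if not indexes:
        return None
    axis = depth % len(points[0])              # cycle through the dimensions
    indexes = sorted(indexes, key=lambda i: points[i][axis])
    mid = len(indexes) // 2                    # median split keeps the tree balanced
    return {
        "index": indexes[mid],
        "axis": axis,
        "left": build_kdtree(points, indexes[:mid], depth + 1),
        "right": build_kdtree(points, indexes[mid + 1:], depth + 1),
    }

def radius_search(node, points, query, radius, found=None):
    """Collect indexes of all points within `radius` (Euclidean) of `query`."""
    if found is None:
        found = []
    if node is None:
        return found
    point = points[node["index"]]
    if math.dist(point, query) <= radius:
        found.append(node["index"])
    diff = query[node["axis"]] - point[node["axis"]]
    near, far = (node["left"], node["right"]) if diff < 0 else (node["right"], node["left"])
    radius_search(near, points, query, radius, found)
    if abs(diff) <= radius:                    # cross the split plane only if the ball reaches it
        radius_search(far, points, query, radius, found)
    return found

# Toy 2-D "representations" for five items (hypothetical data).
vectors = [(0.0, 0.0), (0.1, 0.1), (1.0, 1.0), (0.9, 1.1), (5.0, 5.0)]
tree = build_kdtree(vectors)
similar = sorted(radius_search(tree, vectors, query=(0.05, 0.05), radius=0.2))
print(similar)
```

The search only descends into the far side of a split when the query ball crosses the splitting plane, which is where the logarithmic lookup time in the ideal case comes from.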

Hierarchical Tree Clustering


KD-Tree Indexing

The above approach is generally suitable when we are in a purely unsupervised setting i.e. we do not know which items are similar and which are not. For e.g. in question-answer mining, we cannot know without looking at all the questions themselves, which of them are similar.

Moreover, learning good representations in an unsupervised setting requires lots of data.

But assume that we have supervised data or some high-level abstraction that places similar items under one group. For e.g. in e-commerce data, 'product type' such as 'Laptops', 'T-Shirts', 'Mobile Phones' etc. is a high-level abstraction, because two items under 'Laptops' would be more similar to each other than an item from 'Laptops' is to an item from 'T-Shirts'.

Similar Items based on Product Type

The level of abstraction can go to any depth, for e.g. different color and size variants of the same t-shirt, or different t-shirts from the same brand, and so on. The granularity of abstraction we want depends on the kind of problem and the business requirement.

In unsupervised similarity models, each feature of an item (words, n-grams etc.) has equal weightage, or at best a TF-IDF weighting. If we instead train a supervised model on pairs of items labelled as similar or dissimilar, the model learns how much weightage to give each feature according to how important that feature is for finding similar items.

For e.g. if two t-shirts have the following titles:

"Katso Men's Cotton T-Shirt"

"Maniac Men's Cotton T-Shirt"

Although 3 of the 4 words are the same, these items are not similar because they belong to different brands, "Katso" and "Maniac" respectively. Thus one word makes all the difference here. Only a supervised model can learn that the weightage of the words "Katso" and "Maniac" matters.

Thus we train a Siamese network on pairs of items, giving label 1 to a pair in which the two items are similar and label 0 to a pair in which they are dissimilar.
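
To make the example concrete, here is a small sketch of an unweighted word-overlap score for the two titles. The helper name `token_overlap` is hypothetical and the score is only one of many possible unweighted measures; it is shown purely to illustrate why equal weighting rates these two titles as highly similar.

```python
def token_overlap(title_a, title_b):
    """Fraction of title_a's tokens that also appear in title_b."""
    tokens_a = title_a.lower().split()
    tokens_b = set(title_b.lower().split())
    shared = [t for t in tokens_a if t in tokens_b]
    return len(shared) / len(tokens_a)

overlap = token_overlap("Katso Men's Cotton T-Shirt", "Maniac Men's Cotton T-Shirt")
print(overlap)  # 0.75
```

Every word counts equally here, so the brand name contributes no more to the score than "Cotton" does.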

  • For text data we use an Embedding+BiLSTM based Siamese network.
  • For image data we use a CNN+MaxPool based Siamese network.
  • Merge the shared layers of the Siamese network using L2 distance.
  • Output is the sigmoid function of the weighted L2 distance.
  • Model is trained using the 'Adam' optimizer and the 'binary_crossentropy' loss function.
  • Batch size for text is kept as 32 whereas for image it is kept at 64.

BiLSTM based Text Siamese Model (This is just an example code):

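# Imports assumed for this example (Keras 2.x API).
from keras import backend as K
from keras.layers import Input, Embedding, Bidirectional, LSTM, Lambda, Dense
from keras.models import Model

def init_model(self):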
    input1 = Input(shape=(self.max_words,))
    input2 = Input(shape=(self.max_words,))

    embed_layer = Embedding(input_dim=self.vocab_size + 1, output_dim=self.embedding_size, input_length=self.max_words, mask_zero=True)

    embed1 = embed_layer(input1)
    embed2 = embed_layer(input2)

    bilstm_layer = Bidirectional(LSTM(units=64))

    bilstm_w1 = bilstm_layer(embed1)
    bilstm_w2 = bilstm_layer(embed2)

    merge = Lambda(lambda x: K.sqrt(K.maximum(K.sum(K.square(x[0]-x[1]), axis=1, keepdims=True), K.epsilon())))([bilstm_w1, bilstm_w2])
    out = Dense(1, activation="sigmoid")(merge)

    self.model = Model([input1, input2], out)
    self.model.compile(optimizer="adam", loss='binary_crossentropy', metrics=['accuracy'])

CNN based Image Siamese Model (This is just an example code):

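# Imports assumed for this example (Keras 2.x API); 'km' is assumed to be the keras-metrics package.
from keras import backend as K, optimizers
from keras.layers import Input, Conv2D, BatchNormalization, MaxPooling2D, Flatten, Dense, Lambda
from keras.models import Model
import keras_metrics as km

def get_shared_model(image_shape):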
    input = Input(shape=image_shape)
    n_layer = input
    n_layer = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(n_layer)
    n_layer = BatchNormalization()(n_layer)
    n_layer = MaxPooling2D(pool_size=(2, 2))(n_layer)
    n_layer = Conv2D(filters=128, kernel_size=(3, 3), activation='relu')(n_layer)
    n_layer = BatchNormalization()(n_layer)
    n_layer = MaxPooling2D(pool_size=(2, 2))(n_layer)

    n_layer = Conv2D(filters=256, kernel_size=(3, 3), activation='relu')(n_layer) 
    n_layer = BatchNormalization()(n_layer) 
    n_layer = MaxPooling2D(pool_size=(2, 2))(n_layer)
    n_layer = Conv2D(filters=256, kernel_size=(3, 3), activation='relu')(n_layer)
    n_layer = BatchNormalization()(n_layer)
    n_layer = MaxPooling2D(pool_size=(2, 2))(n_layer)

    n_layer = Flatten()(n_layer)
    n_layer = Dense(128, activation='linear')(n_layer)
    n_layer = BatchNormalization()(n_layer)
    model = Model(inputs=[input], outputs=[n_layer])
    return model

def init_model(self):
    image_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, 3)

    input_a, input_b = Input(shape=image_shape), Input(shape=image_shape)

    shared_model = get_shared_model(image_shape)

    shared_model_a, shared_model_b = shared_model(input_a), shared_model(input_b)

    n_layer = Lambda(lambda x: K.sqrt(K.maximum(K.sum(K.square(x[0]-x[1]), axis=1, keepdims=True), K.epsilon())))([shared_model_a, shared_model_b])
    n_layer = BatchNormalization()(n_layer)

    out = Dense(1, activation="sigmoid")(n_layer)

    self.model = Model(inputs=[input_a, input_b], outputs=[out])

    adam = optimizers.Adam(lr=0.001)
    self.model.compile(optimizer=adam, loss="binary_crossentropy", metrics=['accuracy', km.precision(label=0), km.recall(label=0)])

Most importantly we need to select the positive and the negative pairs in such a way that the model effectively learns which words or patches of image are more important to an item for finding other similar items.

  1. Find unsupervised representation for each item such as word embeddings for text and perceptual hashes or auto-encoder representations for images.
  2. Select positive pairs of items such that the two items belong to the same true abstract group, such as the same brand or product type, but have a high distance in the unsupervised representation space.
    1. Find pairwise euclidean distances between unsupervised item representations.
    2. Select the pairs from same abstract group with the highest distances.
  3. Select negative pair of items such that the two items have minimum 'distance' in the unsupervised representation space but belong to different true abstract groups.
    1. Index the unsupervised item representations in a KD-Tree.
    2. Index only one item from each abstract group.
    3. For each item, find the items nearest to it in the KD-Tree.
    4. Since the index holds only one item from each abstract group, fetching the 2 nearest items ensures that at least one of them belongs to a different group.
  4. Train the pairs using the Siamese network as shown above.
  5. Finally, we can obtain the item representations from the layer before the L2 merge layer in the Siamese network.
  6. Once we obtain the item representations, the indexing part and the lookup part remain the same as for any unsupervised model.
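
The pair selection in steps 2 and 3 can be sketched as follows. This is a brute-force illustration over all pairs, not the KD-Tree based procedure described above, and it only makes sense for tiny inputs; the function name `mine_hard_pairs`, the toy vectors and the group labels are all hypothetical.

```python
import math
from itertools import combinations

def mine_hard_pairs(vectors, groups, num_pairs=1):
    """Return (hard_positives, hard_negatives) as lists of index pairs.

    Hard positives: same group, largest distance in the unsupervised space.
    Hard negatives: different groups, smallest distance in that space.
    """
    same, different = [], []
    for i, j in combinations(range(len(vectors)), 2):
        dist = math.dist(vectors[i], vectors[j])
        (same if groups[i] == groups[j] else different).append((dist, i, j))
    positives = [(i, j) for _, i, j in sorted(same, reverse=True)[:num_pairs]]
    negatives = [(i, j) for _, i, j in sorted(different)[:num_pairs]]
    return positives, negatives

# Toy unsupervised representations and their (hypothetical) product types.
vectors = [(0.0, 0.0), (3.0, 0.0), (0.5, 0.0), (3.2, 0.0)]
groups = ["laptop", "laptop", "t-shirt", "t-shirt"]
positives, negatives = mine_hard_pairs(vectors, groups)
print(positives, negatives)
```

The positive pairs would then be given label 1 and the negative pairs label 0 before training. For a large corpus the quadratic loop would have to be replaced by a nearest-neighbour index, as the steps above suggest.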

Siamese Network + KD Tree for Inferencing

Obtain item embeddings from BiLSTM Siamese Network:

def get_embeddings(self, X):
    embeddings = K.function([self.model.layers[0].input, self.model.layers[1].input], [self.model.layers[3].get_output_at(0)])
    return embeddings([X, X])[0]

Note that we could have also used the Siamese model at the time of lookup for the nearest matching items. For e.g. Pair the query item with all other items in our training corpus and pass through the Siamese network and let the network predict 1 or 0. If the prediction is 1 then an item is similar to the query item else it is not similar.

But there are obvious drawbacks to this:

  1. We need to pair up the query item against all other items. So if there are 30 million items, then for each query, we need to make 30 million comparisons to find the similar items.
  2. Instead of embeddings we need to store the items themselves, along with all kinds of metadata, which is obviously much heavier in size. Embeddings act as a dimensionality reduction step for us.

But how can we ensure that two items which the Siamese network predicts to be similar are exactly those whose representations lie within some Euclidean radius of each other, and vice versa?

Recall that while training the Siamese network we used the L2 distance to merge the two inputs, i.e. the outputs of the last layer before the merge layer. These outputs are also the item representations. After the L2 merge, this output is passed through a weighted sigmoid layer with weight, say, W and bias B.

Thus if the inputs to the L2 distance layer were U1 and U2 and the output from the L2 layer is X, then:

X = ||U1-U2||_2

Then the input to the sigmoid layer is Y = W*X+B

Thus the sigmoid layer output is Z = (1.0 + exp(-Y))^-1

We know that the model predicts an output of 1 when the sigmoid layer output is greater than 0.5, thus the euclidean distance X must satisfy:

(1.0 + exp(-Y))^-1 > 0.5

W and B must have opposite signs for the model to predict both classes. Since X >= 0, if both were positive then Y would always be positive and the model would predict 1 for every pair, and if both were negative it would never predict 1. If the weight W is negative, the slope of the line Y=W*X+B is negative, i.e. on increasing X, Y decreases and thus the sigmoid output Z decreases.

Thus the sigmoid output Z decreases monotonically as X increases when W is negative. Since Z > 0.5 is equivalent to Y > 0, in this case we get an upper bound on X from the above inequality:

X < -B/W

When W < 0. Blue curve is Y and Orange curve is Z.

When W > 0, the sigmoid output Z increases monotonically with X. In this case we find a lower bound on X i.e. X > -B/W.

When W > 0. Blue curve is Y and orange curve is Z.

In the 1st case, we can directly query the KD-Tree to return similar items within a query radius of -B/W, because this is the upper bound on the distance between similar items.

Whereas in the 2nd case, items returned by the KD-Tree within a query radius of -B/W are actually the dissimilar items, so we need to take the complement of the results by taking a set difference from all the items.

If we want to apply a different threshold T other than 0.5, then the condition becomes:

(1.0 + exp(-Y))^-1 > T

Let Q = -(B+ln((1.0/T)-1))/W

X < Q if W < 0 else X > Q

Thus we can simulate the output of the Siamese network just by using the item representations, and the weight and the bias of the output sigmoid layer. Moreover we can index the item representations into a KD-Tree which reduces our lookup time from O(D*N) to O(D*logN) (in an ideal case).

Python method to get distance threshold for similarity/dissimilarity for the BiLSTM Siamese model:

def get_distance_threshold(self, thres):
    weight, bias = self.model.model.layers[5].get_weights()
    return -(float(bias[0]+math.log((1.0/thres)-1.0)))/weight[0][0]
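
The equivalence derived above can be checked numerically without Keras. The sketch below is a standalone illustration: the values of `W`, `B` and `T` are made-up stand-ins for the learned weight, bias and chosen threshold, and the helper names are hypothetical.

```python
import math

def sigmoid(y):
    return 1.0 / (1.0 + math.exp(-y))

def distance_threshold(weight, bias, thres=0.5):
    """Q = -(B + ln(1/T - 1)) / W, the distance at which the sigmoid output equals T."""
    return -(bias + math.log(1.0 / thres - 1.0)) / weight

def is_similar(distance, weight, bias, thres=0.5):
    """Reproduce the Siamese decision from the L2 distance alone."""
    q = distance_threshold(weight, bias, thres)
    return distance < q if weight < 0 else distance > q

# Hypothetical learned parameters of the final sigmoid layer.
W, B, T = -2.0, 3.0, 0.9
for x in [0.0, 0.2, 0.4, 0.6, 1.0, 2.0]:
    # The distance test must agree with thresholding the sigmoid output.
    assert is_similar(x, W, B, T) == (sigmoid(W * x + B) > T)
q = distance_threshold(W, B, T)
print(round(q, 4))
```

With a negative weight, `q` is the query radius to pass to the KD-Tree; with a positive weight the same value is the lower bound and the result set has to be complemented.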

Now let's check out some of the challenges in the implementation discussed so far:

  • Keras 'model.fit' with all data loaded into memory together will cause a memory error.
    • The machine has 16 GB of RAM but the data size is 50 GB.
    • Use 'model.fit_generator' instead of 'model.fit' in Keras. The training and validation data need to be supplied to the 'fit_generator' method through a Python generator.
    • With a generator only one batch (e.g. of size 64) has to be in memory at a time, which drastically reduces the memory requirements.

Using generators for reading data and training deep learning models (This is just an example code):

def fit(self):
    # Stop early when validation loss stalls and keep only the best checkpoint.
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=3),
        ModelCheckpoint(filepath=self.model_file_path, monitor='val_loss', save_best_only=True),
    ]

    if self.use_generator:
        train_num_batches = int(math.ceil(float(self.training_samples)/self.batch_size))
        valid_num_batches = int(math.ceil(float(self.validation_samples)/self.batch_size))

        # Batches are produced lazily, so the full dataset never sits in memory.
        self.model.fit_generator(dg.get_data_as_generator(self.training_samples, 'train', batch_size=self.batch_size),
                                 steps_per_epoch=train_num_batches,
                                 validation_data=dg.get_data_as_generator(self.validation_samples, 'validation', batch_size=self.batch_size),
                                 validation_steps=valid_num_batches,
                                 callbacks=callbacks,
                                 epochs=3, verbose=1, use_multiprocessing=True)
    else:
        # Small datasets only: everything is loaded into memory at once.
        X_train, y_train = dg.get_data_as_vanilla(self.training_samples, 'train')
        X_valid, y_valid = dg.get_data_as_vanilla(self.validation_samples, 'validation')

        self.model.fit(X_train, y_train,
                       validation_data=(X_valid, y_valid),
                       callbacks=callbacks,
                       epochs=3, verbose=1, shuffle=True)
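
The 'dg.get_data_as_generator' helper used above is not shown in this post. As a rough idea of what such a generator might look like, here is a minimal sketch; the name `batch_generator` and the toy data are hypothetical, and in real use each slice would be read from disk rather than from an in-memory list.

```python
def batch_generator(samples, labels, batch_size):
    """Yield (inputs, targets) batches forever, one batch at a time."""
    n = len(samples)
    while True:                                # Keras-style generators loop indefinitely
        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            yield samples[start:end], labels[start:end]

# Toy data standing in for rows that would be read lazily from disk.
samples = list(range(10))
labels = [s % 2 for s in samples]
gen = batch_generator(samples, labels, batch_size=4)
first = next(gen)
batches = [first] + [next(gen) for _ in range(3)]
print(batches[2], batches[3])
```

Because the generator never ends, the number of batches per epoch has to be passed to the training call separately, which is why the batch counts are computed with a ceiling division.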

  • Reading and saving data in in-memory data structures will cause a memory error.
    • Generally we read the data, pre-process it in memory and also store it in in-memory data structures such as a Numpy array.
    • Assuming each item is a numpy array of size 256 (at 8 bytes per value) and the number of items is 30 million, the total memory requirement would be around 57 GB.
    • Use disk based database storage such as MongoDB or Pytables.
    • Pytables is much more convenient, as its syntax is similar to that of a numpy array but it persists data on disk and reads from disk.
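
def insert_embeddings_pytables(self, batch_size=25000):
    # Extendable on-disk array of 128-dimensional float32 embeddings.
    embeds_file = tables.open_file('data/embeddings.h5', mode='w')
    atom = tables.Float32Atom()
    embeds_arr = embeds_file.create_earray(embeds_file.root, 'data', atom, (0, 128))

    sents_arr_file = tables.open_file('data/sent_arrays.h5', mode='r')
    # Assumption: the sentence arrays live under a node named 'data' (the original assignment was lost).
    sents_arr = sents_arr_file.root.data

    n = len(sents_arr)
    num_batches = int(math.ceil(float(n)/batch_size))

    for m in range(num_batches):
        start, end = m*batch_size, min((m+1)*batch_size, n)
        embeds = self.model.get_embeddings(sents_arr[start:end,:])
        # Assumption: each batch of embeddings is appended to the on-disk array.
        embeds_arr.append(embeds)

    sents_arr_file.close()
    embeds_file.close()


def fetch_embeddings_pytables(self, item_indexes=None):
    embeds_file = tables.open_file('data/embeddings.h5', mode='r')
    embeds_arr = embeds_file.root.data

    if item_indexes is not None:
        # Read only the requested rows from disk.
        output = np.array([embeds_arr[i] for i in item_indexes])
    else:
        output = np.array([embeds_arr[i] for i in range(len(embeds_arr))])

    embeds_file.close()
    return output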
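
The 57 GB estimate mentioned above can be reproduced with a quick calculation. The sketch assumes 8-byte (float64) values, which is NumPy's default but is not stated explicitly in the post; the helper name `array_memory_gib` is hypothetical.

```python
def array_memory_gib(num_items, dim, bytes_per_value):
    """Memory needed for a dense num_items x dim array, in GiB (2**30 bytes)."""
    return num_items * dim * bytes_per_value / 2**30

float64_gib = array_memory_gib(30_000_000, 256, 8)   # float64: 8 bytes per value
float32_gib = array_memory_gib(30_000_000, 256, 4)   # float32: 4 bytes per value
print(round(float64_gib, 1), round(float32_gib, 1))  # 57.2 28.6
```

Halving the precision halves the footprint, but even then the array would not fit into a 16 GB machine, hence the disk-backed storage.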

  • KD-Tree problems.
    • To create the KD-Tree we need to load all the embeddings into memory at once, because the splitting axis and the split point are computed using all the data. For 1 billion items we would need around 1 TB of RAM!
      • We can incrementally add embeddings to the tree similar to any other tree insertion algorithm.
      • But the tree can quickly become un-balanced and may affect the search time adversely.
      • The structure of the tree depends on the order in which the items are inserted into the tree.
    • When the dimension of the embeddings exceeds 20, nearest neighbour search in a KD-Tree becomes almost comparable to linear search.
      • Although we use a dimension of 128, our query radius is really small (i.e. at a confidence of almost 99%), so the KD-Tree still performs very well because it rarely has to backtrack into opposite branches.
    • KD-Tree in its original implementation cannot be distributed across machines. When the data size becomes too large, the KD-Tree will not fit on a single machine and needs to be stored in a distributed manner.
      • Build a separate KD-Tree for each abstract group, such as a KD-Tree for each product type of items.
      • Apart from enabling us to store different KD-Trees in different servers it also enables faster query because now each KD-Tree indexes fewer items.
    • Standard implementations do not support incremental updates, or efficiently merging two KD-Trees, without re-building the tree.
    • Python implementations of KD-Tree also store the embeddings in the leaves of the tree. Instead of storing the embeddings themselves we can store the index references of the embeddings, while the embeddings are stored in a Pytables data structure on disk.
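
The effect of insertion order on tree balance, mentioned in the list above, can be seen in a toy example. The sketch compares a KD-tree built by inserting points one at a time with one built by median splits; all function names and the data are hypothetical, and sorted diagonal points are deliberately chosen as a worst case.

```python
def insert(node, point, depth=0):
    """Insert a point into a KD-tree without any rebalancing."""
    if node is None:
        return {"point": point, "left": None, "right": None}
    axis = depth % len(point)
    side = "left" if point[axis] < node["point"][axis] else "right"
    node[side] = insert(node[side], point, depth + 1)
    return node

def build_balanced(points, depth=0):
    """Build a KD-tree by splitting at the median, which needs all points up front."""
    if not points:
        return None
    axis = depth % len(points[0])
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2
    return {"point": points[mid],
            "left": build_balanced(points[:mid], depth + 1),
            "right": build_balanced(points[mid + 1:], depth + 1)}

def height(node):
    return 0 if node is None else 1 + max(height(node["left"]), height(node["right"]))

# 63 points on a diagonal, inserted in sorted order (a worst case for incremental inserts).
points = [(float(i), float(i)) for i in range(63)]
incremental = None
for p in points:
    incremental = insert(incremental, p)
balanced = build_balanced(points)
print(height(incremental), height(balanced))  # 63 6
```

The incrementally built tree degenerates into a chain, so a lookup in it is no better than a linear scan, whereas the median-split tree keeps logarithmic height at the cost of needing all the data in memory.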

In the next few posts I will be exploring alternative Fast Spatial Indexing data structures and discuss their implementations.

The full code is available in my GitHub repository:

