Stokastik

Machine Learning, AI and Programming

Speeding up with Cython

While I was working with the R programming language, coming from a Java/C++ background I always found it somewhat slow. Then I discovered the "Rcpp" package, which allowed me to write C++ code and call it from R. It improved my program speed by 50-100x. Most of my coding was then done in C++ while the interface remained in R. Since moving to Python, I have found that using libraries like "NumPy" greatly improves performance compared to doing the calculations manually in a for-loop. But I always wanted the raw power of C/C++ in my Python code.

What held me back was the ease of coding in Python. I was more focused on solving the problem than on gaining a 10-20 ms speed boost by optimizing the code (assuming the algorithm itself is already optimized). But once I started doing production deployments, I realized that when multiple such pieces of code are part of a pipeline, the delays add up.

Cython is the 'C equivalent' of Python: every valid Python program is also a valid Cython program (but the reverse is not true).

  • Python is an interpreted, dynamically typed language: we do not need to specify data types, and types are resolved at run-time.
  • Cython is a compiled language like C and C++. It is statically typed: we specify data types in the code itself, so no type-resolution cost is incurred at run-time.
  • The best part of Cython is that instead of writing an entire module in C or C++ and then importing it, Cython requires only minimal changes to existing Python code, yet yields a large boost in speed.
  • We can also import code written entirely in C or C++ using Cython, similar to Rcpp.
  • Loops are inherently slow in an interpreted language; this is mitigated by moving entire loop-heavy blocks of code down to C or C++.

Without getting into the internals of the Cython language, I will demonstrate how to write Python code in Cython, and analyze and compare the run-times.

The first example I am going to show is generating primes up to some N using the Sieve of Eratosthenes. The sieve works by eliminating all multiples of a prime number in each iteration: first remove all multiples of 2, then multiples of 3, 5, 7 and so on.

Observe that after we are done eliminating all multiples of a prime p_i, the next number to consider is the next prime p_{i+1}, because all composite numbers between p_i and p_{i+1} have already been eliminated by p_0, p_1, ..., p_i. Also, the smallest remaining multiple of p_i is p_i^2, because all multiples of p_i less than p_i^2 have already been eliminated.

Create a file named "sieve_python.py" and add the following Python code to it:

import numpy as np

def sieve(n):
    arr = np.empty(n+1, dtype=np.int32)
    arr.fill(1)
    arr[0], arr[1] = 0, 0

    sqrt_n = int(np.sqrt(n))
    for i in range(2, sqrt_n+1):
        if arr[i] == 1:
            j = i**2
            while j <= n:
                arr[j] = 0
                j += i

    return np.nonzero(arr)

Note that the above code is not fully optimized, because we are not actually 'eliminating' the composites but just setting their flag to 0. Alternatively, we could use a doubly linked list, which lets us delete the nodes corresponding to composites in O(1) time. But for comparison purposes, let's stick with this version for now.
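As a quick sanity check of the sieve logic, here is a minimal re-implementation using a plain Python list instead of NumPy (so it runs anywhere), producing the same result on a small N:

```python
def sieve_check(n):
    # 1 marks a candidate prime, 0 marks a composite
    arr = [1] * (n + 1)
    arr[0] = arr[1] = 0

    i = 2
    while i * i <= n:          # same as iterating up to sqrt(n)
        if arr[i] == 1:
            j = i * i          # first surviving multiple of i
            while j <= n:
                arr[j] = 0
                j += i
        i += 1

    return [i for i in range(2, n + 1) if arr[i] == 1]

print(sieve_check(30))  # → [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```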

We need to benchmark the time taken by the plain Python version. In a Jupyter or IPython notebook, run the following code:

from sieve_python import sieve
%timeit sieve(5000000)

We get the following output:

1 loop, best of 3: 3.15 s per loop

Thus, the plain Python version takes around 3.15 s to obtain all prime numbers less than N=5 million.
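Outside a notebook, the standard-library `timeit` module gives the same kind of "best of N" measurement as `%timeit`. A stand-in workload is used here so the snippet is self-contained; in practice you would time `sieve` imported from `sieve_python`:

```python
import timeit

def workload(n):
    # stand-in for `sieve`; replace with your own function under test
    return sum(1 for i in range(2, n) if all(i % p for p in (2, 3, 5, 7)))

# timeit.repeat mirrors %timeit's "best of 3" behaviour
times = timeit.repeat(lambda: workload(10000), repeat=3, number=1)
print(f"best of 3: {min(times):.4f} s per loop")
```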

Let's create the Cython version for the same code. Create a file named 'sieve_cython.pyx' in the same folder.

import numpy as np

def sieve(int n):
    cdef int[:] arr = np.empty(n+1, dtype=np.int32)
    cdef int sqrt_n, i, j
    
    arr[:] = 1
    arr[0], arr[1] = 0, 0

    sqrt_n = int(np.sqrt(n))
    for i in range(2, sqrt_n+1):
        if arr[i] == 1:
            j = i**2
            while j <= n:
                arr[j] = 0
                j += i

    return np.nonzero(arr)

The line 'cdef int[:] arr = np.empty(n+1, dtype=np.int32)' creates a typed memory-view over an empty NumPy array. This is special to Cython: both NumPy arrays and C arrays can be cast into Cython memory-views.

The 'dtype' of the NumPy array is set to 'np.int32' because a C 'int' is 32-bit on common platforms, whereas 'np.int' typically maps to 'np.int64', i.e. a 64-bit integer ('long' on 64-bit Linux).
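These size correspondences can be checked from Python itself with the standard `struct` module, whose native format codes follow the platform's C compiler (note that the size of C 'long' varies by platform, e.g. 4 bytes on Windows, 8 on 64-bit Linux):

```python
import struct

print(struct.calcsize("i"))  # C int: 4 bytes (32-bit) on all common platforms
print(struct.calcsize("l"))  # C long: platform-dependent (4 or 8 bytes)
print(struct.calcsize("q"))  # C long long: always 8 bytes (64-bit)
```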

Since Cython is a compiled language, we need to build and compile the above Cython code. Create a "setup.py" file in the same location:

from distutils.core import setup
from Cython.Build import cythonize

setup(name="sieve_cython", ext_modules=cythonize('sieve_cython.pyx'),)

Run the following command on the command line to build the Cython file:

python setup.py build_ext --inplace

This will generate a C file "sieve_cython.c", which is the C translation of the Cython code, and a "build" folder containing the Python distribution package.

Now we can run the Cython version of the sieve method in the same way as a Python function:

from sieve_cython import sieve
%timeit sieve(5000000)

With Cython code for sieve algorithm we obtain the following timing:

10 loops, best of 3: 62.1 ms per loop

Just 62.1 ms compared to 3.15 s for the naive Python version, i.e. an improvement of around 50x. The changes to the original Python code are minimal: we only added data types to the variables.

We can check which parts of the Cython code use pure Python and which parts use C by running the following command:

cython sieve_cython.pyx -a

Which will generate a .html file as follows:

[Figure: annotated HTML output for sieve_cython.pyx]

The lines highlighted in bright yellow use pure Python, whereas lines without any yellow background use pure C. The brighter the yellow background, the more 'Pythonic' the line is.

This gives us a hint about where to focus in order to improve the performance of our code. But we notice that most of the yellow bands are around NumPy method calls, which are already highly optimized, so it might not be possible to improve the Cython code much further.

Without giving up hope, we try to use mostly pure C syntax:

from libc.stdlib cimport malloc, free
from libc.math cimport sqrt

def sieve(int n):
    cdef int *arr = <int *> malloc((n+1) * sizeof(int))
    cdef int sqrt_n, i, j
    
    for i in range(2, n+1):
        arr[i] = 1

    sqrt_n = int(sqrt(n))
    for i in range(2, sqrt_n+1):
        if arr[i] == 1:
            j = i**2
            while j <= n:
                arr[j] = 0
                j += i

    cdef list out = [i for i in range(2, n+1) if arr[i] == 1]
    free(arr)
    return out

We define the array 'arr' to be a C array created using the 'malloc' function, and the memory is released using 'free'.

We have removed all references to NumPy from the code. Instead of NumPy's square root, we use C's built-in 'sqrt' function from 'libc.math'.

The generated HTML file looks like:

[Figure: annotated HTML output for the pure-C-syntax version]

Since the function's result is returned to Python, the variable 'out' is declared as a Python list using 'cdef list'.

We make a few changes to the 'setup.py' file in order to link the math library and enable fast-math for the square root computation:

from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext

ext_modules=[ Extension("sieve_cython",
              ["sieve_cython.pyx"],
              libraries=["m"],
              extra_compile_args = ["-ffast-math"])]

setup(
  name = "sieve_cython",
  cmdclass = {"build_ext": build_ext},
  ext_modules = ext_modules)

Then we build the code and check the timing of the new version with 'almost' pure C syntax:

from sieve_cython import sieve
%timeit sieve(5000000)

The output is:

10 loops, best of 3: 56.5 ms per loop

Thus we now do slightly better than the previous Cython code due to the use of pure C syntax. The improvement is almost 56x over the naive Python implementation.

Additional Note:

For those interested in the doubly linked list implementation of the sieve algorithm, a Cython implementation is as follows:

from libc.stdlib cimport malloc, free
from libc.math cimport sqrt

cdef struct Node:
    int index
    Node *prev
    Node *next
    
ctypedef Node* myNode

def sieve(int n):
    cdef int sqrt_n, i, j
    cdef myNode curr_node, del_node
    
    cdef myNode *node_ref = <myNode*> malloc((n+1) * sizeof(myNode))
    
    node_ref[0], node_ref[1] = NULL, NULL
    
    for i in range(2, n+1):
        node_ref[i] = <myNode> malloc(sizeof(Node))
        node_ref[i].index = i
        node_ref[i].prev, node_ref[i].next = NULL, NULL
        
        if i > 2:
            curr_node.next = node_ref[i]
            node_ref[i].prev = curr_node
            
        curr_node = node_ref[i]
    
    sqrt_n, curr_node = int(sqrt(n)), node_ref[2]
    
    while curr_node != NULL and curr_node.index <= sqrt_n:
        i = curr_node.index
        j = i**2
        while j <= n:
            if node_ref[j] != NULL:
                del_node = node_ref[j]

                if del_node.prev != NULL:
                    del_node.prev.next = del_node.next
                if del_node.next != NULL:
                    del_node.next.prev = del_node.prev
                node_ref[j] = NULL
                free(del_node)
            j += i
            
        curr_node = curr_node.next

    cdef list out = [i for i in range(2, n+1) if node_ref[i] != NULL]

    # free the surviving (prime) nodes before releasing the reference array
    for i in range(2, n+1):
        if node_ref[i] != NULL:
            free(node_ref[i])
    free(node_ref)
    return out

First we create a doubly linked list and keep an array of references 'node_ref', where 'node_ref[i]' points to the node for i in the linked list. This pointer array lets us access the node to be deleted in O(1) time, since unlike arrays we cannot index into a linked list directly. Deletion from a doubly linked list is also O(1) once we have the node to be deleted.
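The O(1) unlink step that the struct-based code performs can be sketched in plain Python (a hypothetical `Node` class mirroring the C struct, with a dict standing in for the `node_ref` array):

```python
class Node:
    # mirrors the C struct: an index plus prev/next links
    def __init__(self, index):
        self.index = index
        self.prev = None
        self.next = None

def unlink(node_ref, j):
    # O(1): node_ref gives direct access to the node, then we splice it out
    node = node_ref[j]
    if node.prev is not None:
        node.prev.next = node.next
    if node.next is not None:
        node.next.prev = node.prev
    node_ref[j] = None

# build 2 <-> 3 <-> 4, then delete 3
node_ref = {i: Node(i) for i in (2, 3, 4)}
node_ref[2].next, node_ref[3].prev = node_ref[3], node_ref[2]
node_ref[3].next, node_ref[4].prev = node_ref[4], node_ref[3]

unlink(node_ref, 3)
print(node_ref[2].next.index)  # → 4
```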

Although the linked list approach should have a better asymptotic run-time than the array implementation, because the prime gap is conjectured to grow only with the square of the logarithm of the current prime:

g_i = p_{i+1} - p_i, with g_i = O((log p_i)^2)

But due to the additional overhead of creating the linked list and changing references of previous and next pointers of the nodes, the overall time taken comes out to be quite large (650 ms) as compared to the simple array implementation (56.5 ms).

But the sieve is actually a reusable data structure: once we find all primes less than or equal to a large enough N, we can reuse the result to get all primes less than or equal to any M < N, which is the more practical use case.
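Reusing a precomputed sieve for a smaller bound only needs a binary search over the sorted prime list, e.g. with the standard `bisect` module (a sketch; the prime list here is assumed to come from one sieve run with N=50):

```python
import bisect

def primes_up_to(primes, m):
    # primes: sorted list of all primes <= N (from a single sieve run)
    # returns all primes <= m for any m <= N via one binary search
    return primes[:bisect.bisect_right(primes, m)]

primes_to_50 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]
print(primes_up_to(primes_to_50, 30))  # → [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```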

The Cython code is shared in my Git repository.

Categories: MACHINE LEARNING, PROBLEM SOLVING
