Limitations of numba

When numba works, it works really well and is very easy to use.

However, the type of code that numba can accelerate is very limited. It works well when working directly with numeric data within simple loops, but struggles as soon as you call arbitrary Python functions or interact with Python objects. This is because numba cannot get around a fundamental property of Python: every function call or object method call is indirect, and involves interacting with the Python virtual machine.

For example, let’s add a progress bar to our square root function:

import math
import numpy as np
import numba
import tqdm

@numba.jit()
def calculate_roots(numbers):
    num_vals = len(numbers)
    result = np.zeros(num_vals, "f")

    for i in tqdm.tqdm(range(0, num_vals)):
        result[i] = math.sqrt(numbers[i])

    return result

numbers = np.random.rand(10000000)
result = calculate_roots(numbers)
100%|█████████████████████████████████████████████████████████████████████| 10000000/10000000 [00:02<00:00, 4361047.79it/s]

(4.4 million square roots per second)

Timing this shows that it is much slower:

timeit(calculate_roots(numbers))
100%|██████████████████████████| 10000000/10000000 [00:02<00:00, 4088038.31it/s]
100%|██████████████████████████| 10000000/10000000 [00:02<00:00, 4042940.75it/s]
100%|██████████████████████████| 10000000/10000000 [00:02<00:00, 4154570.11it/s]
100%|██████████████████████████| 10000000/10000000 [00:02<00:00, 4008025.41it/s]
100%|██████████████████████████| 10000000/10000000 [00:02<00:00, 4013519.05it/s]
100%|██████████████████████████| 10000000/10000000 [00:02<00:00, 4003978.49it/s]
100%|██████████████████████████| 10000000/10000000 [00:02<00:00, 4194120.72it/s]
100%|██████████████████████████| 10000000/10000000 [00:02<00:00, 4180628.33it/s]
2.45 s ± 47.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This is about 800 times slower than the serial numba code. You may also have seen, as I did, lots of warnings printed to the screen:

  warnings.warn(errors.NumbaWarning(warn_msg,
/path/to/python3.8/site-packages/numba/core/object_mode_passes.py:161: NumbaDeprecationWarning:
Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.

For more information visit https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit

We get this warning because the tqdm progress bar is a normal Python object. numba does not work well with standard Python objects, including Python containers such as lists or dictionaries. The function to be accelerated has to operate only on simple types (e.g. numbers such as floats or integers), and data has to be held in arrays, such as numpy arrays. The reason is that accessing any function on a Python object means calling a Python function that is external to the @numba.jit() function that you are accelerating. This means moving out of your numba-compiled code and back into Python code that is executed on the Python virtual machine.

You may ask how the len and math.sqrt calls work, as these are also function calls. In these cases, numba has special code that recognises that math.sqrt is really just a call to a built-in square root function, and so it replaces the Python math.sqrt with a pre-compiled square root. It also recognises that len is simply looking up the size of the array, so it does that directly.

These kinds of transformations mean that numba will only accelerate code where data is held in arrays (e.g. numpy arrays) and where the operations performed on that data can be mapped to standard operations (e.g. +, -, *, /) or built-in functions provided by numba.

(you can see a full list of supported Python features here)

In particular, the developers of numba have focussed on accelerating Python code that uses numpy. As such, using numpy is recommended, with lots of numpy functionality being available to be accelerated via numba.

(you can see a full list of supported numpy features here)

At the time of writing, numba does not support pandas. This means that it can be difficult to use numba to accelerate scripts that make heavy use of pandas.

see here for a more detailed answer to the pandas question

nopython mode

Printing a warning is helpful, but there are many times when you would prefer that the script fails if it can’t be accelerated. Accelerated scripts can run hundreds or even thousands of times faster than non-accelerated scripts. This means that if the accelerated script takes seconds, the non-accelerated script could take hours (there are 3600 seconds in an hour). In this case, it is better that the script exits quickly with an error than wastes hours of compute time running non-accelerated.

You can tell numba to exit if it is unable to fully accelerate a function. You do this by adding nopython=True to the decorator, e.g.

import math
import numpy as np
import numba
import tqdm

@numba.jit(nopython=True)
def calculate_roots(numbers):
    num_vals = len(numbers)
    result = np.zeros(num_vals, "f")

    for i in tqdm.tqdm(range(0, num_vals)):
        result[i] = math.sqrt(numbers[i])

    return result

numbers = np.random.rand(10000000)
result = calculate_roots(numbers)
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Input In [14], in <cell line: 17>()
     14     return result
     16 numbers = np.random.rand(10000000)
---> 17 result = calculate_roots(numbers)

File ~/conda_arm64/lib/python3.8/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File ~/conda_arm64/lib/python3.8/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)

Factoring out numba-supported code into functions

There are times where we do need to mix acceleratable and non-acceleratable code. For example, how could we accelerate our square-root code while keeping the progress bar?

In general, you do this by factoring out the numba-only part into a “worker” function, which is only called by an outer function.

For example, here we factor out the loop into a numba-only part, which is called from an outer loop that drives the progress bar.

@numba.jit(nopython=True, parallel=True)
def inner_calculate_roots(numbers, result, start, end):
    for i in numba.prange(start, end):
        result[i] = math.sqrt(numbers[i])


def calculate_roots(numbers):
    num_vals = len(numbers)
    result = np.zeros(num_vals, "f")
    nblocks = 10
    # Ceiling division, so that nblocks blocks cover every element
    num_per_block = (num_vals + nblocks - 1) // nblocks

    for i in tqdm.tqdm(range(0, nblocks), unit_scale=num_per_block):
        start = i * num_per_block
        end = min(num_vals, (i+1)*num_per_block)
        inner_calculate_roots(numbers, result, start, end)

    return result

numbers = np.random.rand(10000000)
result = calculate_roots(numbers)
100%|███████████████████████| 10000000/10000000 [00:00<00:00, 1569372146.97it/s]

(1.6 billion(!) square roots per second)

timeit(calculate_roots(numbers))
10.1 ms ± 93.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Chunking up the loop into blocks has enabled us to get the speed of numba, while also retaining useful Python code, such as using a progress bar. It has also allowed us to parallelise the loop with numba.prange, while still retaining a progress bar with tqdm.tqdm.
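The block arithmetic is the part that is easiest to get wrong, so it can help to isolate it. The sketch below is one possible way to do that in pure Python; block_ranges is a hypothetical helper, not part of the code above, and it assumes nblocks is at least 1.

```python
def block_ranges(num_vals, nblocks):
    """Yield (start, end) index pairs that together cover range(num_vals)."""
    # Ceiling division: the smallest block size for which nblocks blocks suffice
    num_per_block = (num_vals + nblocks - 1) // nblocks

    for i in range(nblocks):
        start = i * num_per_block
        end = min(num_vals, start + num_per_block)

        if start < end:  # skip empty trailing blocks
            yield start, end


blocks = list(block_ranges(10, 3))
print(blocks)  # [(0, 4), (4, 8), (8, 10)]
```

Each (start, end) pair could then be passed to a numba-compiled worker such as inner_calculate_roots, with the progress bar wrapped around the outer loop over blocks.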

This is useful, as progress bars provide a great tool for embedding the timing of your code into your script. Here, we can see that numba has accelerated this code from ~4.4 million iterations per second to ~1.6 billion iterations per second.

Conclusion

numba is extremely powerful and very simple to use when it works. You have to be careful to place everything into arrays, and be mindful of which code is numeric (e.g. operations and functions called on numeric data held in numpy arrays), and which code is Python object code. You have to do the work to separate Python object code and numeric code into, ideally, separate functions, and then add @numba.jit(nopython=True, cache=True) decorators to the numeric functions.

You can parallelise your numba code by adding parallel=True and switching to numba.prange for the loops that you want to run in parallel. If you have time, you can check out this bonus material that shows you how to best optimise parallel numba code.

Finally, numba has excellent documentation. We highly recommend that you read through this if you want to learn more about its features, e.g. how to call external C functions, how to automatically vectorise functions, and how to compile functions that will run on a GPU.

numba is an excellent tool, and a great first step in accelerating your Python scripts. It is truly amazing that, when it works, a simple @numba.jit() decorator can accelerate your script by thousands of times, thereby saving you time, and saving significant amounts of energy.

Bonus Exercise

Add a tqdm progress bar to the calculate_scores function in slow.py.

Answer to this exercise