Rowan Marshall

Automatic memoization of function calls in Python

The Fibonacci sequence is the integer sequence where each number is the sum of the previous two numbers, starting with zero and one:

0, 1, 1, 2, 3, 5, 8, 13, 21, 34

Writing a function to calculate an arbitrary Fibonacci number is a common problem in interviews and tests, and it has an elegant, if inefficient, recursive solution.

Here’s a simple solution in Python:

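def fib(n):
    # Base cases: fib(0) == 0 and fib(1) == 1
    if n < 2:
        return n
    # Recursive case: sum of the two previous Fibonacci numbers
    return fib(n-1) + fib(n-2)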

Given an integer n, this function calculates the n’th number in the Fibonacci sequence. Let’s benchmark it with the ever-useful timeit tool, using a value of 28:

➜ python3 -m timeit -s 'import fib' 'fib.fib(28)'
2 loops, best of 5: 109 msec per loop

Reasonably fast. But if we up this to 30:

➜ python3 -m timeit -s 'import fib' 'fib.fib(30)'
1 loop, best of 5: 291 msec per loop

A small increase almost tripled the algorithm’s runtime. If we increase n much further, the runtime keeps growing exponentially, since this is an O(2^n) algorithm. Why is this so inefficient?

To answer this, we need to look at how the function is being called. Here’s a diagram stolen from Stack Overflow showing the recursive tree of fib(n) calls being made:

That’s 25 separate calls altogether, with 18 of those being repeats. The proportion of repeated calls gets much worse as n increases, meaning a lot of work is being done unnecessarily.
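The call counts above can be checked with a small instrumented version of the function. This is only a sketch: counting_fib and the calls counter are hypothetical names used for illustration, and the input of 6 is an assumption, chosen because it is the value that produces 25 calls with this recursion.

```python
from collections import Counter

# Hypothetical helper: how many times each argument has been seen.
calls = Counter()

def counting_fib(n):
    """Same recursion as fib, but records every call it receives."""
    calls[n] += 1
    if n < 2:
        return n
    return counting_fib(n - 1) + counting_fib(n - 2)

counting_fib(6)

total = sum(calls.values())   # every call, repeats included
distinct = len(calls)         # arguments seen at least once
repeats = total - distinct    # calls whose answer was already known

print(total, distinct, repeats)  # 25 7 18
```

Only 7 of the 25 calls do work that has not been done before; the rest recompute answers we already had.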

How can we fix this, without resorting to changing our algorithm?

Introducing the LRU Cache decorator

Inside the functools module of the standard library, there’s a decorator called @lru_cache. This decorator records the arguments a function is called with and the value the function returned. If that function is called again with the same arguments, lru_cache returns the stored value instead of running the function body a second time. Since our issue is repeated function calls, we should expect a significant speedup.
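To make the idea concrete, here is a minimal sketch of a memoizing decorator written by hand. It is not how functools implements lru_cache: memoize, slow_square and computed are hypothetical names, there is no size limit or eviction, and it only handles positional, hashable arguments.

```python
from functools import wraps

def memoize(func):
    """Simplified illustration of memoization (not the real lru_cache)."""
    cache = {}  # maps argument tuples to previously returned values

    @wraps(func)
    def wrapper(*args):
        if args not in cache:
            # First time we see these arguments: run the real function.
            cache[args] = func(*args)
        # Otherwise, reuse the stored result without calling func.
        return cache[args]

    return wrapper

computed = []  # records each time the function body actually runs

@memoize
def slow_square(x):
    computed.append(x)
    return x * x

first = slow_square(4)   # runs the body
second = slow_square(4)  # answered from the cache

print(first, second, computed)  # 16 16 [4]
```

The body ran once even though the function was called twice, which is exactly the behaviour we want for the repeated fib calls.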

Let’s try timing that again after adding the decorator:

from functools import lru_cache

@lru_cache(maxsize=None)  # keep every result, with no size limit
def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    return fib(n-1) + fib(n-2)

➜ python3 -m timeit -s 'import fib' 'fib.fib(30)'
2000000 loops, best of 5: 105 nsec per loop

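On this benchmark, that is nearly three million times faster (291 msec down to 105 nsec). Bear in mind that the cache persists between timeit loops, so after the very first call each loop is measuring little more than a cache lookup. Still, not bad for one line of code.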
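If you want to see what the cache is doing on a single cold call, the wrapped function exposes cache_info() and cache_clear(). The sketch below assumes an unbounded cache (maxsize=None) and uses the shorter n < 2 base case from the first version of fib.

```python
from functools import lru_cache

@lru_cache(maxsize=None)  # assumption: unbounded cache
def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

result = fib(30)
info = fib.cache_info()

print(result)  # 832040
print(info)    # CacheInfo(hits=28, misses=31, maxsize=None, currsize=31)

# Empty the cache, for example before timing a cold call again.
fib.cache_clear()
```

The function body ran just 31 times, once for each value from 0 to 30, and the other 28 calls were served straight from the cache.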