Automatic memoization of function calls in Python

The Fibonacci sequence is the integer sequence where each number is the sum of the previous two numbers, starting with zero and one:

0, 1, 1, 2, 3, 5, 8, 13, 21, 34

Writing a function to calculate an arbitrary Fibonacci number is a common interview and test problem, and it has an elegant, if inefficient, recursive solution.

Here’s a simple solution in Python:

def fib(n):
    if n < 2:
        return n  # base cases: fib(0) = 0, fib(1) = 1
    return fib(n-1) + fib(n-2)

Given an integer n, this function calculates the n’th number in the Fibonacci sequence. Let’s benchmark it with the ever-useful timeit tool (with the code above saved as fib.py), using a value of 28:

➜ python3 -m timeit -s 'import fib' 'fib.fib(28)'
2 loops, best of 5: 109 msec per loop

Reasonably fast. But if we bump this up to 30:

➜ python3 -m timeit -s 'import fib' 'fib.fib(30)'
1 loop, best of 5: 291 msec per loop

A small increase almost tripled the algorithm’s runtime. If we increase n much further, it gets exponentially slower to calculate, since this is an O(2^n) algorithm. Why is it so inefficient?
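A quick count makes that growth concrete. Writing C(n) for the number of calls the naive fib makes for input n, and F(n) for the n’th Fibonacci number (F(0) = 0, F(1) = 1), each call for n ≥ 2 makes one call of its own plus the calls of its two subtrees. Adding 1 to both sides turns this into the Fibonacci recurrence itself:

```latex
\begin{aligned}
C(0) &= C(1) = 1 \\
C(n) &= C(n-1) + C(n-2) + 1 \qquad (n \ge 2) \\
C(n) + 1 &= \bigl(C(n-1) + 1\bigr) + \bigl(C(n-2) + 1\bigr), \qquad C(0) + 1 = C(1) + 1 = 2 = 2F(1) = 2F(2) \\
\Rightarrow\; C(n) &= 2F(n+1) - 1 \\
C(6) &= 2F(7) - 1 = 2 \cdot 13 - 1 = 25 \\
C(28) &= 2F(29) - 1 = 2 \cdot 514229 - 1 = 1028457 \\
C(30) &= 2F(31) - 1 = 2 \cdot 1346269 - 1 = 2692537 \\
\frac{C(30)}{C(28)} &\approx 2.618 \approx \varphi^2, \qquad \varphi = \tfrac{1 + \sqrt{5}}{2} \approx 1.618
\end{aligned}
```

So each step up in n multiplies the work by roughly φ ≈ 1.6, which lines up with the measured jump from 109 msec to 291 msec (about 2.7×). O(2^n) is a correct upper bound; the tight bound is Θ(φ^n).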

To see why, we need to look at how the function is being called. Here’s a diagram stolen from Stack Overflow showing the recursive tree of fib(n) calls being made:

That’s 25 separate calls altogether for fib(6), 18 of which are repeats. The share of repeated calls gets much worse as n increases, meaning a lot of work is being done unnecessarily.
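The 25/18 figures can be reproduced without the diagram by instrumenting the naive fib with a counter. This is a sketch, not part of the original code: the module-level `calls` tally is an illustrative addition, and it assumes the diagram shows fib(6), which is the input that produces exactly 25 calls with this definition.

```python
from collections import Counter

# Illustrative tally: how many times fib() is entered with each argument.
calls = Counter()

def fib(n):
    calls[n] += 1  # record this call before doing any work
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

fib(6)
total = sum(calls.values())
repeats = total - len(calls)  # every call after the first for a given n is a repeat
print(total, repeats)         # 25 18
print(sorted(calls.items()))  # [(0, 5), (1, 8), (2, 5), (3, 3), (4, 2), (5, 1), (6, 1)]
```

Only 7 distinct arguments (0 through 6) are ever needed, yet fib(1) alone is computed 8 times.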

How can we fix this without resorting to changing our algorithm?

Introducing the LRU Cache decorator

Inside the functools standard library module, there’s a decorator called @lru_cache. It records the arguments each call was made with and the value the function returned. If the function is called again with the same arguments, lru_cache returns the stored value instead of running the function a second time. (“LRU” stands for least recently used: once the cache holds maxsize results, 128 by default, the least recently used entry is discarded.) Since our problem is repeated function calls, we should expect a significant speedup.
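To make the mechanism concrete, here is a deliberately simplified memoizing decorator built on a plain dict. The name memoize is hypothetical, and this is not how functools implements lru_cache: it never evicts entries, ignores keyword arguments, and assumes every argument is hashable.

```python
import functools

def memoize(func):
    """Hypothetical, simplified stand-in for @lru_cache (unbounded, positional args only)."""
    cache = {}  # maps a tuple of arguments -> the value func returned for them

    @functools.wraps(func)  # keep the original function's name and docstring
    def wrapper(*args):
        if args in cache:
            return cache[args]  # seen these arguments before: skip the call entirely
        result = func(*args)    # first time: actually run the function
        cache[args] = result
        return result

    return wrapper

@memoize
def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

print(fib(30))  # 832040
```

Decorating is enough to fix the recursion too, because the recursive calls inside fib look up the global name fib, which now refers to the caching wrapper, so every subproblem goes through the cache.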

Let’s try timing that again after adding the decorator:

from functools import lru_cache

@lru_cache()  # cache results keyed by the arguments (up to 128 entries by default)
def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    return fib(n-1) + fib(n-2)

➜ python3 -m timeit -s 'import fib' 'fib.fib(30)'
2000000 loops, best of 5: 105 nsec per loop

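That’s a speedup of well over a million times: 291 msec down to 105 nsec. One caveat: timeit calls fib.fib(30) repeatedly in the same process, so after the first loop the answer comes straight from the cache, and the 105 nsec figure mostly measures a cache lookup. Even the first, uncached call is dramatically cheaper, though, because it computes each of fib(0) through fib(30) only once instead of making about 2.7 million calls. Not bad for one line of code.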
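Functions wrapped by lru_cache also expose cache_info() and cache_clear(), which make it easy to separate the first, uncached call from later cached ones. The sketch below redefines fib so it runs on its own; the perf_counter timings are only illustrative and will vary by machine.

```python
import time
from functools import lru_cache

@lru_cache()  # default maxsize=128
def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

fib.cache_clear()  # start from an empty cache

start = time.perf_counter()
fib(30)  # first call: each fib(k) for k = 0..30 is computed once and cached
first = time.perf_counter() - start
info = fib.cache_info()
print(info)  # CacheInfo(hits=28, misses=31, maxsize=128, currsize=31)

start = time.perf_counter()
fib(30)  # second call: a single cache hit, the function body never runs
second = time.perf_counter() - start

print(f"first call: {first * 1e6:.1f} microseconds")
print(f"cached call: {second * 1e6:.1f} microseconds")
```

The 31 misses are fib(0) through fib(30), each computed exactly once; the 28 hits are the second recursive call, fib(n-2), for n = 3 to 30, which finds its answer already in the cache.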