The Fibonacci sequence is the integer sequence where each number is the sum of the previous two numbers, starting with zero and one:
0, 1, 1, 2, 3, 5, 8, 13, 21, 34
Writing a function to calculate an arbitrary Fibonacci number is a common problem in interviews and tests, with an elegant, if inefficient, recursive solution.
Here’s a simple solution in Python:
def fib(n):
    if n < 2:
        return n
    else:
        return fib(n-1) + fib(n-2)
Given an integer n, this function calculates the nth number in the Fibonacci sequence. Let's benchmark it with the ever-useful timeit tool, using a value of 28:
➜ python3 -m timeit -s 'import fib' 'fib.fib(28)'
2 loops, best of 5: 109 msec per loop
Reasonably fast. But if we up this to 30…
➜ python3 -m timeit -s 'import fib' 'fib.fib(30)'
1 loop, best of 5: 291 msec per loop
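Increasing n by just 2 almost tripled the algorithm's runtime. Pushing n much higher makes it exponentially slower to calculate, since this is an O(2^n) algorithm. More precisely, the number of calls grows by a factor of about 1.618 (the golden ratio) for each step in n, which matches the roughly 2.7x jump from n = 28 to n = 30. Why is this so inefficient?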
To answer this, we need to look at how the function is being called. Here's a diagram stolen from Stack Overflow showing the recursive tree of fib(n) calls being made:
For fib(6), that's 25 separate calls altogether, with 18 of those being repeats. The share of repeated calls gets much worse as n increases, meaning a lot of work is being done unnecessarily.
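To check those numbers without the diagram, here is a small sketch that counts how often each argument is seen. The `calls` counter and the choice of n = 6 are assumptions made for illustration (6 is the input for which the naive function makes exactly 25 calls); the function body is the same naive recursion as above.

```python
from collections import Counter

calls = Counter()  # maps each argument n to the number of times fib(n) was called


def fib(n):
    calls[n] += 1  # record this call before doing any work
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)


fib(6)
total_calls = sum(calls.values())          # every call, repeats included
repeated_calls = total_calls - len(calls)  # calls beyond the first one for each n
print(total_calls, repeated_calls)         # prints: 25 18
```

Only 7 distinct arguments (0 through 6) are ever needed, so everything beyond those 7 calls is recomputation.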
How can we fix this without changing our algorithm?
Introducing the LRU Cache decorator
In the functools standard library module, there's a decorator called @lru_cache. This decorator records the arguments a function is called with and the value the function returned. If that function is called again with the same arguments, lru_cache returns the value it stored the first time, saving us the cost of computing it again. Since our issue is repeated function calls, we should expect a significant speedup.
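To make the idea concrete, here is a minimal sketch of a caching decorator built on a plain dictionary. It is not how functools implements lru_cache: the real decorator also handles keyword arguments, limits the cache size, and evicts the least recently used entries. The names `memoize`, `square`, and `body_runs` are hypothetical and exist only for this example.

```python
import functools


def memoize(func):
    """Simplified stand-in for lru_cache: unbounded, positional hashable args only."""
    cache = {}  # maps argument tuples to previously computed results

    @functools.wraps(func)
    def wrapper(*args):
        if args not in cache:
            cache[args] = func(*args)  # first time: run the function and remember the result
        return cache[args]             # later calls: return the stored value

    return wrapper


body_runs = []  # records each time the real function body executes


@memoize
def square(x):
    body_runs.append(x)
    return x * x


first = square(4)   # runs the function body and stores the result
second = square(4)  # served from the dictionary; the body does not run again
print(first, second, body_runs)  # prints: 16 16 [4]
```

The second call never reaches the function body, which is exactly the saving we want for the repeated fib(n) calls.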
Let's try timing that again after adding the decorator:
from functools import lru_cache

@lru_cache
def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n-1) + fib(n-2)
➜ python3 -m timeit -s 'import fib' 'fib.fib(30)'
2000000 loops, best of 5: 105 nsec per loop
That sped up our benchmark by more than a million times, from 291 msec to 105 nsec per loop. Not bad for one line of code.
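One way to see where the time went is the `cache_info()` method that lru_cache attaches to the decorated function. The sketch below assumes a fresh interpreter and writes out `maxsize=128`, which is the documented default, so that it also runs on Python versions older than 3.8, where the bare `@lru_cache` form is not accepted.

```python
from functools import lru_cache


@lru_cache(maxsize=128)  # 128 is the default size, spelled out explicitly here
def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)


result = fib(30)
info = fib.cache_info()
print(result)  # prints: 832040
print(info)    # prints: CacheInfo(hits=28, misses=31, maxsize=128, currsize=31)
```

Each value from fib(0) to fib(30) is computed exactly once (31 misses), and every other call is answered from the cache. It also explains the nanosecond timing: timeit calls fib(30) repeatedly in the same process, so after the first loop each call is a single cache lookup.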