Andrzej Pragacz

Tail recursion in Python, part 1: trampolines

In this blog post I will try to explain what tail recursion is and show how tail call optimization can be achieved in Python using trampolines.

Python is not a functional programming language, but it allows you to write programs in a functional way.

Functional programming is based on these principles:

  • Functions are first-class citizens: you can define functions and call them, but more importantly you can pass functions as parameters to another function. You can also return a function as the result of calling another function.
  • Referential transparency and immutable data structures: functions have no side effects, i.e. calling a given function many times with the same parameters always gives the same result. Also, there is no way to re-assign a variable to a new value.
  • Recursion is allowed, i.e. a function can call itself.
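
As a small illustration of the first principle, here is a sketch of a function that takes two functions and returns a new one. The names compose, increment and double are made up for this example and are not used elsewhere in the post.

```python
def compose(f, g):
    """Return a new function that applies g first and then f."""
    def composed(x):
        return f(g(x))
    return composed


def increment(x):
    return x + 1


def double(x):
    return x * 2


# Functions are passed in as arguments, and a function comes back as the result.
increment_then_double = compose(double, increment)
print(increment_then_double(3))  # (3 + 1) * 2 = 8
```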

Recursion is required because without re-assignable variables you cannot write a loop that updates a counter or an accumulator.

The problem with recursion is that each function call needs to store some data in order to restore the caller's state when the call finishes. This is problematic when you want to make a lot of nested calls, because it may lead to the infamous stack overflow.

To counteract this, many functional languages (like Haskell or OCaml) can transform a function written in tail-recursive form into a function that internally runs a loop. This process is called tail call optimization.

A function is tail-recursive only when it:

  1. returns the direct result of a recursive call to itself
  2. does not make any recursive call other than the one described in point 1

This makes the call graph (calling context tree) of the recursive calls linear. And because we return the direct result of the recursive call, in theory we don't need to store the state of the function when making that call: when the call finishes there is no state to restore, we just return the value computed by the called function.

Therefore we can simulate these nested recursive calls with a loop, where the function parameters become variables that change in each iteration.
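
To make the transformation concrete, here is a hand-written sketch of a tail-recursive function next to the loop it corresponds to. The functions sum_to and sum_to_loop are hypothetical examples; Python does not perform this rewrite for you.

```python
def sum_to(n, acc=0):
    """Tail-recursive: the recursive call is the last thing evaluated."""
    if n == 0:
        return acc
    return sum_to(n - 1, acc + n)


def sum_to_loop(n, acc=0):
    """The same computation, with the parameters turned into loop variables."""
    while n != 0:
        # Both parameters are updated at once, exactly as the recursive call would.
        n, acc = n - 1, acc + n
    return acc


print(sum_to(100))           # 5050
print(sum_to_loop(100))      # 5050
print(sum_to_loop(100000))   # 5000050000, far deeper than recursion could go by default
```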

The example

Let's try an example. We will calculate the Fibonacci sequence, defined as follows:

$$
f_n = \begin{cases} 0 & \text{if } n = 0 \\ 1 & \text{if } n = 1 \\ f_{n-1} + f_{n-2} & \text{if } n \ge 2 \end{cases}
$$

The function calculating the sequence, written in Python, looks as follows:

def fib(n):
    assert n >= 0
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)

As we can see, this function is not tail-recursive, because it doesn't return the direct result of a recursive call to fib but the sum of two recursive calls. So even if Python supported tail call optimization out of the box, this function wouldn't be tail call optimized.

Let's test it with some input values:

>>> [fib(n) for n in range(10)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]

>>> fib(10000)
...
RecursionError: maximum recursion depth exceeded in comparison

Ooops! It seems that we stumbled upon the previously mentioned stack overflow. In other words: each nested function call needs to store some data so that the state of the program can be properly restored when the given function finishes. The space that stores this kind of information is limited, so the depth of nested function calls is limited too.
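
In CPython this limit is an explicit setting that you can inspect. The sketch below uses a hypothetical depth function to hit it on purpose; the default limit is 1000, but your environment may have changed it.

```python
import sys


def depth(n):
    """Recurse n levels deep and return n. Not a tail call: 1 is added afterwards."""
    if n == 0:
        return 0
    return 1 + depth(n - 1)


print(sys.getrecursionlimit())  # 1000 unless it has been changed

try:
    depth(sys.getrecursionlimit() + 100)
except RecursionError as exc:
    print("failed:", exc)
```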

Let's try to rewrite the function in a tail-recursive fashion. To do this, we need to re-think how to calculate the next Fibonacci number. We can do that by tracking two consecutive Fibonacci numbers at each step.

$$
\begin{array}{ll} a_0 = f_0 = 0 & b_0 = f_1 = 1 \\ a_1 = f_1 = b_0 & b_1 = f_2 = f_1 + f_0 = b_0 + a_0 \\ a_2 = f_2 = b_1 & b_2 = f_3 = f_2 + f_1 = b_1 + a_1 \\ \dots & \dots \end{array}
$$

We can write this in Python as follows:

def fib(n):

    def helper(k, a, b):
        if k == 0:
            return a
        else:
            a2 = b
            b2 = a + b
            return helper(k - 1, a2, b2)

    assert n >= 0
    return helper(n, 0, 1)

With that, we actually optimized the function. Now the time complexity of fib(n) is linear ($O(n)$ in Big O notation) instead of exponential ($O(2^n)$).
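
One rough way to see the difference is to count function calls. The sketch below is an illustration only: count_naive_calls and count_helper_calls are hypothetical variants of the two fib versions that also return how many calls were made, and the count ignores the cost of adding large integers.

```python
def count_naive_calls(n):
    """Return (fib(n), number of calls made) for the naive definition."""
    if n < 2:
        return n, 1
    a, calls_a = count_naive_calls(n - 1)
    b, calls_b = count_naive_calls(n - 2)
    return a + b, calls_a + calls_b + 1


def count_helper_calls(n):
    """Return (fib(n), number of calls made) for the accumulator version."""
    def helper(k, a, b, calls):
        if k == 0:
            return a, calls
        return helper(k - 1, b, a + b, calls + 1)

    # The initial call to helper counts as the first call.
    return helper(n, 0, 1, 1)


print(count_naive_calls(20))   # (6765, 21891)
print(count_helper_calls(20))  # (6765, 21)
```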

Also, the internal helper function is now in tail-recursive form, because it returns the direct result of the recursive call helper(k - 1, a2, b2). The fib function only primes the helper function with the start parameters.

However, we still haven't solved the stack overflow problem:

>>> fib(10000)
...
RecursionError: maximum recursion depth exceeded in comparison
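
For comparison, this is roughly the loop that a language with tail call optimization would run in place of helper. It is a hand-written sketch, and fib_loop is a name made up for this illustration.

```python
def fib_loop(n):
    assert n >= 0
    # k, a and b play the role of helper's parameters.
    k, a, b = n, 0, 1
    while k != 0:
        # Equivalent to the tail call helper(k - 1, b, a + b).
        k, a, b = k - 1, b, a + b
    return a


print([fib_loop(n) for n in range(10)])  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
big = fib_loop(10000)                    # completes without a RecursionError
```

The rest of this post is about getting the same effect while keeping the recursive shape of the code.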

Trampolines

Let's introduce a new concept: trampolines.

The main idea is that we want to "delay" the execution of the recursive call by wrapping it in a lambda expression, which can be handed back up the call stack and executed there.
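
A minimal sketch of what "delaying" means here: wrapping a call in a lambda creates an object that can be passed around, and the call itself only happens when that object is invoked. The names expensive, calls and thunk are hypothetical.

```python
calls = []


def expensive(x):
    calls.append(x)  # record that the body actually ran
    return x * 2


thunk = lambda: expensive(21)  # nothing is executed yet

print(calls)    # []
print(thunk())  # 42
print(calls)    # [21]
```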

We will use a decorator to describe this:

import functools

def trampoline_tailrec(f):

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Run the first step, then keep calling whatever comes back
        # until the result is no longer callable.
        result = f(*args, **kwargs)
        while callable(result):
            result = result()
        return result

    return wrapper

We need to make a slight change to our fib function:

def fib(n):

    @trampoline_tailrec
    def helper(k, a, b):
        if k == 0:
            return a
        else:
            a2 = b
            b2 = a + b
            return lambda: helper(k - 1, a2, b2)

    assert n >= 0
    return helper(n, 0, 1)

Let's test the new version.

>>> [fib(n) for n in range(10)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
>>> fib(10000)
...
RecursionError: maximum recursion depth exceeded in comparison

Hmmm. We still have a problem with stack overflow. What happened here?

The problem is that the helper(k - 1, a2, b2) call in the lambda expression refers to the decorated function, not the original one. Therefore, instead of one while loop we get many nested while loops, each performing just one step before starting the next.
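
The nesting can be made visible by instrumenting the decorator. The sketch below is a hypothetical variant (counting_trampoline, stats and countdown are made-up names) that records how many trampoline loops are active at the same time.

```python
import functools

# "active" is the number of trampoline loops currently running,
# "max" is the deepest nesting seen so far.
stats = {"active": 0, "max": 0}


def counting_trampoline(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        stats["active"] += 1
        stats["max"] = max(stats["max"], stats["active"])
        try:
            result = f(*args, **kwargs)
            while callable(result):
                result = result()
            return result
        finally:
            stats["active"] -= 1
    return wrapper


@counting_trampoline
def countdown(k):
    if k == 0:
        return "done"
    # countdown is the decorated wrapper here, so each thunk starts a new loop.
    return lambda: countdown(k - 1)


print(countdown(50), stats["max"])  # done 51
```

With a single loop doing all the work the maximum would stay at 1; here it grows with the input, which is why the recursion limit is still reached.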

Let's fix this by not using the decorator syntax:

def fib(n):

    def helper(k, a, b):
        if k == 0:
            return a
        else:
            a2 = b
            b2 = a + b
            return lambda: helper(k - 1, a2, b2)

    assert n >= 0
    decorated_helper = trampoline_tailrec(helper)
    return decorated_helper(n, 0, 1)

>>> [fib(n) for n in range(10)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
>>> fib(10000)


Finally! But there are a few problems with this:

  • We would like to use the decorator syntax.
  • We assume that if the function returns another callable then the callable is a trampoline. This may not always be true.
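
The second problem is easy to trigger. In the sketch below, make_greeter is a hypothetical function that is meant to return another function as its result; trampoline_tailrec is repeated from above only so that the snippet runs on its own.

```python
import functools


def trampoline_tailrec(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        result = f(*args, **kwargs)
        while callable(result):
            result = result()
        return result
    return wrapper


@trampoline_tailrec
def make_greeter(name):
    # The intended return value is a function, not a delayed recursive call.
    return lambda: "Hello, " + name


greeter = make_greeter("Ada")
print(greeter)            # Hello, Ada
print(callable(greeter))  # False: the loop already called the returned function
```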

The refactor

Let's define a separate Trampoline type to distinguish trampoline objects from other callables. There are various ways to do that; we will use namedtuple to avoid boilerplate code.

from collections import namedtuple

Trampoline = namedtuple('Trampoline', ['lazy_expr'])
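
A quick sketch of why this helps: an isinstance check tells a Trampoline apart from an ordinary callable, and the Trampoline itself is not callable, so the two can no longer be confused. The names step and plain are hypothetical.

```python
from collections import namedtuple

Trampoline = namedtuple('Trampoline', ['lazy_expr'])

step = Trampoline(lambda: 42)  # a delayed computation, explicitly marked
plain = lambda: 42             # an ordinary callable

print(isinstance(step, Trampoline))   # True
print(isinstance(plain, Trampoline))  # False
print(callable(step))                 # False
print(step.lazy_expr())               # 42
```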

We will also wrap the function in a callable object instead of a plain function. This allows us to define a .trampoline() method which generates a Trampoline object for the given function arguments.

from collections.abc import Callable


class TrampolineCallable(Callable):

    def __init__(self, f):
        self.f = f

    def __call__(self, *args, **kwargs):
        result = self.f(*args, **kwargs)
        while isinstance(result, Trampoline):
            result = result.lazy_expr()
        return result

    def trampoline(self, *args, **kwargs):
        return Trampoline(lambda: self.f(*args, **kwargs))

def trampoline_tailrec(f):
    return functools.wraps(f)(TrampolineCallable(f))

Let's try the new version of the fib function:

def fib(n):

    @trampoline_tailrec
    def helper(k, a, b):
        if k == 0:
            return a
        else:
            a2 = b
            b2 = a + b
            return helper.trampoline(k - 1, a2, b2)

    assert n >= 0
    return helper(n, 0, 1)

>>> [fib(n) for n in range(10)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
>>> fib(10000)


Quite nice. But we still have the problem that we had to modify the helper function to return an explicit trampoline. In the next post we will try to avoid that.
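
To make the required modification concrete, here is the same pattern applied to a different function. This is only a sketch: factorial is a hypothetical example, and the Trampoline, TrampolineCallable and trampoline_tailrec definitions are repeated from above so that the snippet runs on its own.

```python
import functools
import math
from collections import namedtuple
from collections.abc import Callable

Trampoline = namedtuple('Trampoline', ['lazy_expr'])


class TrampolineCallable(Callable):

    def __init__(self, f):
        self.f = f

    def __call__(self, *args, **kwargs):
        result = self.f(*args, **kwargs)
        while isinstance(result, Trampoline):
            result = result.lazy_expr()
        return result

    def trampoline(self, *args, **kwargs):
        return Trampoline(lambda: self.f(*args, **kwargs))


def trampoline_tailrec(f):
    return functools.wraps(f)(TrampolineCallable(f))


@trampoline_tailrec
def factorial(n, acc=1):
    if n <= 1:
        return acc
    # A plain tail call would be: return factorial(n - 1, acc * n)
    return factorial.trampoline(n - 1, acc * n)


print(factorial(5))                             # 120
print(factorial(3000) == math.factorial(3000))  # True
```

The only difference from the plain tail-recursive version is the return statement, and that is exactly the change we would like to get rid of.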

You can experiment with the examples provided above in this Jupyter notebook.