Tail recursion in Python, part 1: trampolines
In this blog post I will try to explain what tail recursion is and show how tail call optimization can be achieved in Python using trampolines.
Python is not a functional programming language, but it allows you to write programs in a functional way.
Functional programming is based on these principles:
- Functions are first-class citizens - you can define functions and call them, but more importantly you can pass functions as parameters to other functions. You can also return a function as the result of calling another function.
- Referential transparency & immutable data structures - functions have no "side effects", i.e. calling a given function many times with the same parameters always gives the same result. Also, there is no way to re-assign a variable to a new value.
- Recursion is allowed, i.e. a function can call itself.
Recursion is required because, with referential transparency and no re-assignment, you can't write an ordinary loop that updates a counter or an accumulator.
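To make the last point concrete, here is a small sketch (the function names total_loop and total_recursive are made up for this illustration) of the same computation written with a loop that re-assigns a variable and with recursion that never re-assigns anything:

```python
def total_loop(numbers):
    # Imperative version: the accumulator is re-assigned on every iteration.
    result = 0
    for x in numbers:
        result = result + x
    return result


def total_recursive(numbers):
    # Functional version: every name is bound exactly once per call;
    # the "iteration" is expressed as a call on the rest of the sequence.
    if not numbers:
        return 0
    head, *tail = numbers
    return head + total_recursive(tail)


print(total_loop([1, 2, 3, 4]))       # 10
print(total_recursive([1, 2, 3, 4]))  # 10
```

Note that total_recursive adds head after the recursive call returns, so every call has to wait for the one it started.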
The problem with recursion is that each function call needs to store some data that is required to restore the caller's state when the call finishes. This is problematic when you want to do a lot of nested calls, because it may lead to the infamous stack overflow.
To counteract this, many functional languages (like Haskell or OCaml) can transform a function written in tail-recursive form into a function that internally runs a loop. This process is called tail call optimization.
A function is tail-recursive only when it:
- returns the direct result of a recursive call to itself
- does not make any recursive call other than the one in point 1
This causes the call graph (calling context tree) of the recursive calls to be linear. And because we return the direct result of the recursive call, in theory we don't need to store the state of the function when making the recursive call: when the call finishes there is no state to restore, we just return the value computed by the called function.
Therefore we can simulate these nested recursive calls with a loop, where the function parameters become variables that change in each iteration.
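As a small illustration of these two rules, here is a sketch using factorial (chosen only as an example; the three function names are not used anywhere else in this post). The first version is not tail-recursive, the second one is, and the third is the loop that the second one can be turned into:

```python
def factorial(n):
    # Not tail-recursive: the multiplication happens after the recursive
    # call returns, so the caller's state (n) has to be kept around.
    if n == 0:
        return 1
    return n * factorial(n - 1)


def factorial_tail(n, acc=1):
    # Tail-recursive: the recursive call is the last thing the function
    # does and its result is returned unchanged.
    if n == 0:
        return acc
    return factorial_tail(n - 1, acc * n)


def factorial_loop(n):
    # The loop that corresponds to factorial_tail:
    # the parameters (n, acc) become variables updated in each iteration.
    acc = 1
    while n != 0:
        n, acc = n - 1, acc * n
    return acc


print(factorial(5), factorial_tail(5), factorial_loop(5))  # 120 120 120
```

The extra acc parameter carries the partial result forward, which is what removes the need to come back to the caller.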
The example
Let's try an example. We will calculate the Fibonacci sequence, defined as below:
F(0) = 0, F(1) = 1, F(n) = F(n - 1) + F(n - 2) for n > 1
The function calculating the sequence, written in Python, looks as follows:
def fib(n):
    assert n >= 0
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)
As we can see, this function is not tail-recursive, because we don't return the direct result of a recursive call of fib but the sum of two recursive calls. So even if Python supported tail call optimization out of the box, this function would not be optimized.
Let's test it with some input values:
>>> [fib(n) for n in range(10)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
>>> fib(10000)
...
RecursionError: maximum recursion depth exceeded in comparison
Oops! It seems that we stumbled upon the previously mentioned stack overflow. In other words: each nested function call needs to store some data so that the state of the program can be properly restored when the given function finishes. The space which stores this kind of information is limited, so the depth of nested function calls is limited too. (In CPython the interpreter enforces its own recursion limit and raises RecursionError before the real stack overflows.)
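If you want to see the limit itself, the sketch below asks the interpreter for it with sys.getrecursionlimit() and then deliberately exceeds it. The depth function is a made-up helper for this illustration, and the exact limit depends on your interpreter and its configuration:

```python
import sys


def depth(n):
    # Each call adds one frame to the call stack until n reaches 0.
    return 0 if n == 0 else 1 + depth(n - 1)


limit = sys.getrecursionlimit()
print(limit)  # commonly 1000 in CPython, but it can be configured

try:
    depth(limit + 100)
except RecursionError as error:
    print("RecursionError:", error)
```

Raising the limit with sys.setrecursionlimit() only postpones the problem, which is why we look for a different approach here.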
Let's try to rewrite the function in a tail-recursive fashion. To do this, we need to re-think how to calculate the next Fibonacci number. We can do that by tracking two consecutive Fibonacci numbers at each step.
We can write this in Python as follows:
def fib(n):
    def helper(k, a, b):
        if k == 0:
            return a
        else:
            a2 = b
            b2 = a + b
            return helper(k - 1, a2, b2)
    assert n >= 0
    return helper(n, 0, 1)
With that, we actually optimized the function. Now the time complexity of fib(n) is linear (O(n) in Big O notation) instead of being exponential (O(2^n)).
Also, the internal helper function is actually in tail-recursive form, because it returns the direct result of the recursive call helper(k - 1, a2, b2).
The fib function only primes the helper function with the start parameters.
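For comparison, this is roughly the loop that a language with tail call optimization would derive from helper. It is a hand-written sketch (fib_loop is a name made up for this illustration), with the parameters k, a and b turned into loop variables:

```python
def fib_loop(n):
    assert n >= 0
    # Same starting values as helper(n, 0, 1).
    k, a, b = n, 0, 1
    while k != 0:
        # Same update as the recursive call helper(k - 1, b, a + b).
        k, a, b = k - 1, b, a + b
    return a


print([fib_loop(n) for n in range(10)])  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
print(len(str(fib_loop(10000))))         # 2090 digits, no RecursionError
```

Python does not perform this transformation for us, so the recursive helper itself still uses one stack frame per call.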
However, we still didn't solve the stack overflow problem:
>>> fib(10000)
...
RecursionError: maximum recursion depth exceeded in comparison
Trampolines
Let's introduce a new concept: trampolines.
The main idea is that we want to "delay" the execution of the recursive call by wrapping it in a lambda expression, which can be handed back up the call stack and executed there.
We will use a decorator to implement this:
import functools

def trampoline_tailrec(f):
    @functools.wraps(f)
    def decorator(*args, **kwargs):
        result = f(*args, **kwargs)
        while callable(result):
            result = result()
        return result
    return decorator
We need to make a slight change to our fib function:
def fib(n):
    @trampoline_tailrec
    def helper(k, a, b):
        if k == 0:
            return a
        else:
            a2 = b
            b2 = a + b
            return lambda: helper(k - 1, a2, b2)
    assert n >= 0
    return helper(n, 0, 1)
Let's test the new version.
>>> [fib(n) for n in range(10)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
>>> fib(10000)
...
RecursionError: maximum recursion depth exceeded in comparison
Hmm. We still have a problem with stack overflow. What happened here?
The problem is that the helper(k - 1, a2, b2) call in the lambda expression refers to the decorated function, not the original function.
Therefore instead of one while loop, we get many nested while loops, each performing just one step.
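We can make the nesting visible with a small experiment. The sketch below is not part of the solution: counting_trampoline is a copy of the decorator with some bookkeeping added, and countdown and stats are names invented for this illustration.

```python
import functools

# Bookkeeping: how many trampoline loops are active, and the deepest nesting seen.
stats = {"active": 0, "deepest": 0}


def counting_trampoline(f):
    # Same logic as trampoline_tailrec, plus counting of active loops.
    @functools.wraps(f)
    def decorator(*args, **kwargs):
        stats["active"] += 1
        stats["deepest"] = max(stats["deepest"], stats["active"])
        try:
            result = f(*args, **kwargs)
            while callable(result):
                result = result()
            return result
        finally:
            stats["active"] -= 1
    return decorator


@counting_trampoline
def countdown(k):
    # The name countdown refers to the decorated function here, so the
    # lambda starts a new trampoline loop instead of returning to the old one.
    return 0 if k == 0 else (lambda: countdown(k - 1))


countdown(50)
print(stats["deepest"])  # 51: one nested loop per step instead of a single loop
```

With 50 steps we end up 51 loops deep, so the stack still grows with the input.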
Let's fix this by not using the decorator syntax:
def fib(n):
    def helper(k, a, b):
        if k == 0:
            return a
        else:
            a2 = b
            b2 = a + b
            return lambda: helper(k - 1, a2, b2)
    assert n >= 0
    decorated_helper = trampoline_tailrec(helper)
    return decorated_helper(n, 0, 1)
>>> [fib(n) for n in range(10)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
>>> fib(10000)
33644764876431783266621612005107543310302148460680063906564769974680081442166662368155595513633734025582065332680836159373734790483865268263040892463056431887354544369559827491606602099884183933864652731300088830269235673613135117579297437854413752130520504347701602264758318906527890855154366159582987279682987510631200575428783453215515103870818298969791613127856265033195487140214287532698187962046936097879900350962302291026368131493195275630227837628441540360584402572114334961180023091208287046088923962328835461505776583271252546093591128203925285393434620904245248929403901706233888991085841065183173360437470737908552631764325733993712871937587746897479926305837065742830161637408969178426378624212835258112820516370298089332099905707920064367426202389783111470054074998459250360633560933883831923386783056136435351892133279732908133732642652633989763922723407882928177953580570993691049175470808931841056146322338217465637321248226383092103297701648054726243842374862411453093812206564914032751086643394517512161526545361333111314042436854805106765843493523836959653428071768775328348234345557366719731392746273629108210679280784718035329131176778924659089938635459327894523777674406192240337638674004021330343297496902028328145933418826817683893072003634795623117103101291953169794607632737589253530772552375943788434504067715555779056450443016640119462580972216729758615026968443146952034614932291105970676243268515992834709891284706740862008587135016260312071903172086094081298321581077282076353186624611278245537208532365305775956430072517744315051539600905168603220349163222640885248852433158051534849622434848299380905070483482449327453732624567755879089187190803662058009594743150052402532709746995318770724376825907419939632265984147498193609285223945039707165443156421328157688908058783183404917434556270520223564846495196112460268313970975069382648706613264507665074611512677522748621598642530711298441182622661057163515069260029861704945425047491378115154139941550671256271197133
252763631939606902895650288268608362241082050562430701794976171121233066073310059947366875
Finally! But there are a few problems with this:
- We would like to use the decorator syntax.
- We assume that if the function returns another callable then the callable is a trampoline. This may not always be true.
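The second problem is easy to reproduce. In the sketch below, make_greeter and make_adder are hypothetical functions whose intended result is itself a function; the callable-based trampoline cannot tell that apart from a delayed call:

```python
import functools


def trampoline_tailrec(f):
    # Copy of the callable-based decorator from this section.
    @functools.wraps(f)
    def decorator(*args, **kwargs):
        result = f(*args, **kwargs)
        while callable(result):
            result = result()
        return result
    return decorator


def make_greeter(name):
    # Returns a function as its actual result, not as a delayed call.
    return lambda: "Hello, " + name


def make_adder(n):
    # Same idea, but the returned function expects an argument.
    return lambda x: x + n


print(callable(make_greeter("Ada")))             # True: we get a function back
print(trampoline_tailrec(make_greeter)("Ada"))   # Hello, Ada: the function was called for us

try:
    trampoline_tailrec(make_adder)(1)
except TypeError as error:
    print("TypeError:", error)  # the trampoline called the adder without arguments
```

In the first case the result silently changes type, in the second the call fails outright.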
The refactor
Let's define a separate Trampoline type, to distinguish trampoline objects from other callables. There are various ways to do that; we will use namedtuple to avoid boilerplate code.
from collections import namedtuple
Trampoline = namedtuple('Trampoline', ['lazy_expr'])
We will also wrap the function with a callable object instead of a usual function. This allows us to define a .trampoline() method, which generates a Trampoline object for the given function arguments.
import functools
from collections.abc import Callable

class TrampolineCallable(Callable):
    def __init__(self, f):
        self.f = f

    def __call__(self, *args, **kwargs):
        # Run the wrapped function, then keep evaluating for as long as
        # it hands back an explicit Trampoline object.
        result = self.f(*args, **kwargs)
        while isinstance(result, Trampoline):
            result = result.lazy_expr()
        return result

    def trampoline(self, *args, **kwargs):
        # Delay a call to the undecorated function.
        return Trampoline(lambda: self.f(*args, **kwargs))

def trampoline_tailrec(f):
    return functools.wraps(f)(TrampolineCallable(f))
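Before moving on, here is a quick check that ordinary callables now survive the trampoline. The definitions are repeated so that the sketch runs on its own; make_adder_slowly is a hypothetical function, invented for this illustration, that recurses like helper but returns a function at the end:

```python
import functools
from collections import namedtuple
from collections.abc import Callable

Trampoline = namedtuple('Trampoline', ['lazy_expr'])


class TrampolineCallable(Callable):
    def __init__(self, f):
        self.f = f

    def __call__(self, *args, **kwargs):
        result = self.f(*args, **kwargs)
        # Only explicit Trampoline objects are evaluated further.
        while isinstance(result, Trampoline):
            result = result.lazy_expr()
        return result

    def trampoline(self, *args, **kwargs):
        return Trampoline(lambda: self.f(*args, **kwargs))


def trampoline_tailrec(f):
    return functools.wraps(f)(TrampolineCallable(f))


@trampoline_tailrec
def make_adder_slowly(k, step):
    if k == 0:
        # The final result is a function; it must be returned, not called.
        return lambda x: x + step
    return make_adder_slowly.trampoline(k - 1, step + 1)


add = make_adder_slowly(10000, 0)
print(add(1))                      # 10001
print(make_adder_slowly.__name__)  # make_adder_slowly, copied by functools.wraps
```

The returned lambda comes back untouched even after 10000 trampolined steps, and functools.wraps keeps the original function name on the wrapper object.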
Let's try the new version of the fib function:
def fib(n):
    @trampoline_tailrec
    def helper(k, a, b):
        if k == 0:
            return a
        else:
            a2 = b
            b2 = a + b
            return helper.trampoline(k - 1, a2, b2)
    assert n >= 0
    return helper(n, 0, 1)
>>> [fib(n) for n in range(10)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
>>> fib(10000)
33644764876431783266621612005107543310302148460680063906564769974680081442166662368155595513633734025582065332680836159373734790483865268263040892463056431887354544369559827491606602099884183933864652731300088830269235673613135117579297437854413752130520504347701602264758318906527890855154366159582987279682987510631200575428783453215515103870818298969791613127856265033195487140214287532698187962046936097879900350962302291026368131493195275630227837628441540360584402572114334961180023091208287046088923962328835461505776583271252546093591128203925285393434620904245248929403901706233888991085841065183173360437470737908552631764325733993712871937587746897479926305837065742830161637408969178426378624212835258112820516370298089332099905707920064367426202389783111470054074998459250360633560933883831923386783056136435351892133279732908133732642652633989763922723407882928177953580570993691049175470808931841056146322338217465637321248226383092103297701648054726243842374862411453093812206564914032751086643394517512161526545361333111314042436854805106765843493523836959653428071768775328348234345557366719731392746273629108210679280784718035329131176778924659089938635459327894523777674406192240337638674004021330343297496902028328145933418826817683893072003634795623117103101291953169794607632737589253530772552375943788434504067715555779056450443016640119462580972216729758615026968443146952034614932291105970676243268515992834709891284706740862008587135016260312071903172086094081298321581077282076353186624611278245537208532365305775956430072517744315051539600905168603220349163222640885248852433158051534849622434848299380905070483482449327453732624567755879089187190803662058009594743150052402532709746995318770724376825907419939632265984147498193609285223945039707165443156421328157688908058783183404917434556270520223564846495196112460268313970975069382648706613264507665074611512677522748621598642530711298441182622661057163515069260029861704945425047491378115154139941550671256271197133
252763631939606902895650288268608362241082050562430701794976171121233066073310059947366875
Quite nice. But we still have the problem that we needed to modify the helper function to return an explicit trampoline.
In the next post we will try to avoid that problem.
You can experiment with the examples provided above in this Jupyter notebook.