When learning about recursion, a common example is the Fibonacci sequence. In simple terms, we say that the nth Fibonacci number is equal to the sum of the (n-1)th Fibonacci number and the (n-2)th Fibonacci number. By adding the rule that the 0th and 1st Fibonacci numbers are 0 and 1 respectively, it’s possible to determine the value of any number in the series. This is pretty simple to define:
    def fib(n: Long): Long = n match {
      case 0 => 0
      case 1 => 1
      case _ => fib(n - 1) + fib(n - 2)
    }
This is a really beautiful function because it describes exactly the rules above. If n is 0 or 1, we return the known value; otherwise, we recursively calculate the (n-1)th and (n-2)th Fibonacci numbers and add them. Unfortunately, a function that mirrors the definition this directly is not always efficient in Scala.
Consider how the computer would work through fib(4). It must traverse a tree of function calls until it reaches either fib(1) or fib(0):
                  fib(4)
                 /      \
            fib(3)      fib(2)
            /    \      /    \
        fib(2) fib(1) fib(1) fib(0)
        /    \
    fib(1) fib(0)
Since we told the computer the values of fib(1) and fib(0), it can finally calculate the total by summing all of these leaves together: 3.
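To make the size of that tree concrete, here is a minimal sketch that returns the number of calls alongside the result. The object name `FibCalls` and the helper `fibCounted` are hypothetical names introduced only for this illustration; they are not part of the original function.

```scala
object FibCalls {
  // Hypothetical helper: returns (fib(n), number of calls made to compute it).
  def fibCounted(n: Long): (Long, Long) =
    if (n < 2) (n, 1L) // base cases: fib(0) = 0, fib(1) = 1, one call each
    else {
      val (x, callsX) = fibCounted(n - 1)
      val (y, callsY) = fibCounted(n - 2)
      (x + y, callsX + callsY + 1) // +1 counts this call itself
    }

  def main(args: Array[String]): Unit =
    for (n <- Seq(4L, 10L, 20L)) {
      val (value, calls) = fibCounted(n)
      println(s"fib($n) = $value after $calls calls")
    }
}
```

For n = 4 this reports 9 calls, one per node in the tree above. For n = 10 it is already 177 calls, and for n = 20 it is 21891.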
This tree quickly increases in complexity (exponentially, on the order of 2^n calls) to the point where the function slows down sharply for even small inputs like 30 or 50. While most Scala programmers seem to agree that a functional style is preferable to an imperative style, the language itself allows us to do either. Should we sacrifice beauty for practicality? Let’s consider an imperative approach.
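    def fib() = {
      var n = 4 // which Fibonacci number to compute
      // Seed with the start of the sequence: 0, 1, 1
      var a = 0
      var b = 1
      var c = 1
      while (n > 0) {
        // Slide the three-value window one step forward
        a = b
        b = c
        c = a + b
        n = n - 1
      }
      a
    }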
This code is a bit harder to read. To start with, we have to declare our n value inside the function, which feels wrong. Then we seed a few variables with the start of the Fibonacci sequence: 0, 1, 1. Here’s a chart to show how the computer processes this loop; each row shows the values of a, b, and c at the end of the iteration that started with the given n:
n | a | b | c |
---|---|---|---|
[seed] | 0 | 1 | 1 |
4 | 1 | 1 | 2 |
3 | 1 | 2 | 3 |
2 | 2 | 3 | 5 |
1 | 3 | 5 | 8 |
Notice how in each iteration of the loop the old value of a drops off? This saves both time and memory. It saves time because there is less work to do (just 4 loop iterations instead of the 9 function calls in the tree above), and it saves memory because we are only keeping track of a, b, c, and n instead of all the branches of the tree in the previous function.
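The awkward hard-coded n is not essential to the imperative approach. As a rough sketch (the names `FibLoop` and `fibLoop` are made up for this example), the same loop can take n as a parameter and keep only two running values:

```scala
object FibLoop {
  // Sketch of the same loop with n as a parameter.
  // Invariant at the top of each iteration: a = fib(i), b = fib(i + 1).
  def fibLoop(n: Int): Long = {
    var a = 0L
    var b = 1L
    var i = 0
    while (i < n) {
      val next = a + b // fib(i + 2)
      a = b
      b = next
      i += 1
    }
    a
  }

  def main(args: Array[String]): Unit =
    println((0 to 10).map(fibLoop).mkString(", "))
}
```

Running it prints the first eleven Fibonacci numbers: 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55. It still relies on mutable variables, which is the part that feels out of place in functional code.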
So while we now have an efficient method, it feels a bit dirty and is not easy to understand simply by looking at the code. Is there a way to balance intuitiveness and efficiency? In fact, we can take advantage of tail recursion in Scala.
    def fib(n: Long): Long = {
      def fibHelper(n: Long, a: Long, b: Long): Long = {
        if (n == 0) a
        else fibHelper(n - 1, b, a + b)
      }
      fibHelper(n, 0, 1)
    }
This function is “tail recursive” because the recursive call (fibHelper calling itself) is the very last thing the function does: its result is returned directly, with no further work waiting on it (the call is in “tail” position). The computer therefore has no need to remember the earlier calls and can reuse a single stack frame. Just like we saw in the imperative approach, this function “forgets” about the previous a value. By applying a tail recursive approach, we get the performance of the imperative method but with a relatively “pure” function. The tail recursive function does not resort to modifying variables, nor does it require n to be defined inside the method.
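To see the “forgetting” in action, the sketch below records the arguments of each successive call. `FibTrace` and `trace` are hypothetical names used only for illustration; the recursion itself is the same as in fibHelper, with an extra accumulator for the log.

```scala
object FibTrace {
  // Same recursion as fibHelper, plus a list that logs each call's (n, a, b).
  def trace(n: Long, a: Long, b: Long,
            seen: List[(Long, Long, Long)] = Nil): List[(Long, Long, Long)] =
    if (n == 0) ((n, a, b) :: seen).reverse
    else trace(n - 1, b, a + b, (n, a, b) :: seen)

  def main(args: Array[String]): Unit =
    trace(4, 0, 1).foreach { case (n, a, b) =>
      println(s"fibHelper($n, $a, $b)")
    }
}
```

For n = 4 the calls are fibHelper(4, 0, 1), fibHelper(3, 1, 1), fibHelper(2, 1, 2), fibHelper(1, 2, 3), and finally fibHelper(0, 3, 5), which returns a = 3. Each call carries everything the next one needs, which mirrors the rows of the imperative chart.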
Here’s a performance comparison for calculating Fibonacci 35:
Type | Stack Memory (bytes) | Time (ms) |
---|---|---|
Naive Recursion | 78892 | 78.8502 |
Imperative | 5056 | 0.0030 |
Tail Recursion | 5248 | 0.0037 |
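The post does not say how these numbers were measured, so the following is only a rough sketch of one way to time the two recursive versions, assuming a single run with `System.nanoTime`. The names `FibTiming`, `naive`, `tailRecursive`, and `timed` are hypothetical, and stack memory is not measured here.

```scala
object FibTiming {
  def naive(n: Long): Long =
    if (n < 2) n else naive(n - 1) + naive(n - 2)

  def tailRecursive(n: Long): Long = {
    def go(n: Long, a: Long, b: Long): Long =
      if (n == 0) a else go(n - 1, b, a + b)
    go(n, 0, 1)
  }

  // Hypothetical helper: evaluates `body` once and returns
  // (result, elapsed wall-clock time in milliseconds).
  def timed[A](body: => A): (A, Double) = {
    val start = System.nanoTime()
    val result = body
    (result, (System.nanoTime() - start) / 1e6)
  }

  def main(args: Array[String]): Unit = {
    val (slow, slowMs) = timed(naive(35))
    val (fast, fastMs) = timed(tailRecursive(35))
    println(f"naive:          $slow in $slowMs%.4f ms")
    println(f"tail recursive: $fast in $fastMs%.4f ms")
  }
}
```

Single-run timings on the JVM are noisy because of JIT warm-up, so the absolute values will likely differ from the table; the gap of several orders of magnitude between the two versions is the part that should be expected to hold.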
Tail recursion is so useful that the Scala compiler will try to optimize self-recursive tail calls automatically when compiling to bytecode. On my version (2.10.4), this function was automatically optimized. However, you can make the intent explicit with the @tailrec annotation (from scala.annotation.tailrec), placed right above the fibHelper method, since that is the method that calls itself. The annotation does not change what the compiler optimizes; it asks the compiler to verify the optimization, and compilation fails with an error if the function is not truly tail-recursive.
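As a minimal sketch of the annotated version (the wrapping object `FibTailrec` is a name chosen for this example), note that the annotation needs an import and goes on the inner helper:

```scala
import scala.annotation.tailrec

object FibTailrec {
  def fib(n: Long): Long = {
    // The annotation goes on the method that actually calls itself.
    @tailrec
    def fibHelper(n: Long, a: Long, b: Long): Long =
      if (n == 0) a else fibHelper(n - 1, b, a + b)

    fibHelper(n, 0, 1)
  }

  // By contrast, annotating the naive version is rejected at compile time,
  // because the addition still has to run after the recursive calls return:
  //
  //   @tailrec
  //   def naive(n: Long): Long =
  //     if (n < 2) n else naive(n - 1) + naive(n - 2)

  def main(args: Array[String]): Unit = {
    println(fib(50))
    // A very deep recursion completes without a StackOverflowError.
    // The value itself overflows Long, so only completion matters here.
    fib(1000000L)
  }
}
```

Because the helper is compiled into a loop, the depth of the recursion is limited by the range of Long rather than by the size of the call stack.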
Much of this was inspired by the discussion of recursion in Structure and Interpretation of Computer Programs. For more details, I highly recommend this book.