Tail call recursion & optimization in Scala
Let's start with recursion and then dive into optimization.
Recursion:
Sometimes we want to execute a piece of code repeatedly, or keep computing something, until a condition is met. We can achieve this in both the functional and the imperative style of programming.
Many of us are quite familiar with the following imperative style, which uses a while loop:
def function1() = {
  val args = List("first.java", "second.scala", "third.python")
  var i = 0
  var foundIt = false
  while (i < args.length && !foundIt) {
    if (args(i).endsWith(".scala")) foundIt = true
    else i = i + 1
  }
  if (!foundIt) i = -1
  i
}
println(function1())
Explanation of the above code:
We have a list of strings, args. We check its elements one by one to see whether any of them ends with ".scala". If one does, we print its position in the list; otherwise we print -1.
Scala is both an object-oriented and a functional programming language. To make this code functional (free of side effects and vars), we can rewrite the same logic as a recursive function:
val args = List("first.java", "second.scala", "third.python")
def function1(position: Int): Int = {
  if (position >= args.length) -1
  else if (args(position).endsWith(".scala")) position
  else function1(position + 1) // recursive call
}
println(function1(0))
The important thing to note here is that the recursive call above makes function1 a tail recursive function.
What is tail call recursion?
A function is tail recursive when the recursive call is the last step it performs: the result of the call is returned as-is, with no computation left to do after the call returns. In the code above, function1(position + 1) is the last step of its branch, which makes function1 tail recursive.
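To make the definition concrete, here is a small sketch contrasting the two shapes. The object and method names (TailVsNonTail, sumNonTail, sumTail) are hypothetical and not part of the original code; summing a list is just a convenient example. In sumNonTail the addition still has to happen after the recursive call returns, while sumTail carries the running total in an accumulator parameter so that the recursive call is the last step.

```scala
object TailVsNonTail {
  // NOT tail recursive: after sumNonTail(xs.tail) returns,
  // this call still has to add xs.head to the result.
  def sumNonTail(xs: List[Int]): Int =
    if (xs.isEmpty) 0
    else xs.head + sumNonTail(xs.tail)

  // Tail recursive: the running total travels in `acc`, so the
  // recursive call is the last step and its result is returned unchanged.
  def sumTail(xs: List[Int], acc: Int = 0): Int =
    if (xs.isEmpty) acc
    else sumTail(xs.tail, acc + xs.head)

  def main(argv: Array[String]): Unit = {
    println(sumNonTail(List(1, 2, 3))) // prints 6
    println(sumTail(List(1, 2, 3)))    // prints 6
  }
}
```

Both return the same value; only the position of the recursive call differs.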
What difference does tail recursion make?
The answer is performance. When a function is tail recursive, the Scala compiler performs tail call optimization, which makes the recursive calls cheaper and also means that deep recursion no longer risks a StackOverflowError.
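If you want the compiler to confirm that a function really is tail recursive, the standard library provides the scala.annotation.tailrec annotation. The sketch below (the object name TailrecCheck is hypothetical) puts it on the search function from above. The compiler optimizes eligible tail calls with or without the annotation; @tailrec only turns a missed optimization into a compile error. Note that scalac only performs this optimization for methods that cannot be overridden, such as methods in an object, final or private methods, and methods defined locally inside another method.

```scala
import scala.annotation.tailrec

object TailrecCheck {
  val args = List("first.java", "second.scala", "third.python")

  // @tailrec asks the compiler to verify that the recursive call is in
  // tail position; if it is not, compilation fails with an error.
  @tailrec
  def function1(position: Int): Int =
    if (position >= args.length) -1
    else if (args(position).endsWith(".scala")) position
    else function1(position + 1)

  def main(argv: Array[String]): Unit =
    println(function1(0)) // prints 1
}
```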
How does the optimization take place? What happens without tail call optimization?
Whenever the Scala compiler recognizes a tail recursive call, it optimizes it.
Without tail call optimization, every invocation of the recursive function allocates a new stack frame (a stack frame holds the data associated with one function call, such as its parameters, local variables and return address).
With tail call optimization, the recursive calls reuse a single stack frame: the compiler turns the tail call into a jump back to the start of the method, just as if the code had been written imperatively with a while loop.
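Conceptually, the optimized function1 behaves like the hand-written loop below. This is only a source-level sketch of the idea (the compiler's actual output is bytecode, and the names ManualTco and function1Loop are hypothetical): the parameter becomes a local variable, and the tail call function1(position + 1) becomes "update the variable and go around the loop again", so one frame is reused for every step.

```scala
object ManualTco {
  val args = List("first.java", "second.scala", "third.python")

  def function1Loop(start: Int): Int = {
    // The former parameter `position` is now a mutable local variable.
    var position = start
    // Each iteration corresponds to one former recursive call.
    while (position < args.length && !args(position).endsWith(".scala"))
      position += 1 // the "tail call": update the argument and loop again
    if (position >= args.length) -1 else position
  }
}
```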
To see how stack frames are allocated in optimized and non-optimized recursion, we need to look at the stack trace. A stack trace is a report of the active stack frames at a certain point during execution: the list of method calls that led to that point. When an exception is thrown and not caught, its stack trace is printed to the console. This helps the programmer follow the flow of execution and find exactly which function call caused the exception.
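Besides reading a printed trace, you can also count stack frames from code. The sketch below (StackDepth, depth, tailDepth and nonTailDepth are hypothetical names) uses the standard JVM call new Exception().getStackTrace to count the frames present at the base case of two recursive functions. The "+ 0" in nonTailDepth is deliberate pending work that leaves the value unchanged but keeps the call out of tail position.

```scala
import scala.annotation.tailrec

object StackDepth {
  // Number of frames on the current thread's stack at this point.
  private def depth(): Int = new Exception().getStackTrace.length

  // Tail recursive: compiled into a loop, so the depth observed at the
  // base case does not depend on n.
  @tailrec
  def tailDepth(n: Int): Int =
    if (n == 0) depth() else tailDepth(n - 1)

  // Not tail recursive: "+ 0" must run after the call returns, so each
  // level keeps its own frame and the observed depth grows with n.
  def nonTailDepth(n: Int): Int =
    if (n == 0) depth() else nonTailDepth(n - 1) + 0

  def main(argv: Array[String]): Unit = {
    println(s"tail:     n=0 -> ${tailDepth(0)}, n=100 -> ${tailDepth(100)}")
    println(s"non-tail: n=0 -> ${nonTailDepth(0)}, n=100 -> ${nonTailDepth(100)}")
  }
}
```

The exact numbers depend on how the program is launched, so only the difference between n = 0 and n = 100 is meaningful.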
Let's print the stack trace for both optimized and non-optimized recursive calls and observe the difference. The code above is tweaked to throw an exception when the searched element is not found in the list, which makes the stack trace easy to view.
Here is the same code, throwing an exception when no element ends with ".scala" (note that "second.scala" has been changed to "second.scal"):
val args = List("first.java", "second.scal", "third.python")
def function1(position: Int): Int = {
  if (position >= args.length) throw new Exception("-1")
  else if (args(position).endsWith(".scala")) position
  else { function1(position + 1) }
}
Output:
function1(0)
java.lang.Exception: -1
at function1(<console>:2)
... 32 elided
As you can see, although function1 is invoked four times (for positions 0 through 3), the stack trace contains only a single function1 frame, because the Scala compiler recognizes the tail recursive call and performs tail call optimization.
Certain patterns prevent the Scala compiler from recognizing tail recursion. In those cases no optimization takes place, i.e., a new stack frame is allocated for each call.
The following are a few of them:
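- Indirect recursion, where a function reaches itself through another function (for example, f calls g and g calls f)
- Calling the function through a function value that wraps it, instead of calling the method directly
- Performing a further operation on the result of the recursive call (for example, function1(position + 1) + 1)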
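Each of the three patterns can be sketched in a few lines. All names below (NotOptimized, isEven, isOdd, factFn, fact, listLength) are hypothetical examples, not from the original code. Every method works correctly on small inputs, but none of them gets the optimization: adding @tailrec to any of them would make compilation fail, and a large enough input would overflow the stack.

```scala
object NotOptimized {
  // 1. Indirect (mutual) recursion: isEven calls isOdd and isOdd calls
  //    isEven. Neither method calls itself, so no loop is produced.
  def isEven(n: Int): Boolean = if (n == 0) true else isOdd(n - 1)
  def isOdd(n: Int): Boolean = if (n == 0) false else isEven(n - 1)

  // 2. Recursion through a function value: the last step looks like a
  //    tail call, but it goes through factFn.apply, not a direct self-call.
  val factFn: (Int, Long) => Long = fact
  def fact(n: Int, acc: Long): Long =
    if (n <= 1) acc else factFn(n - 1, acc * n)

  // 3. Further work on the result: "+ 1" runs after listLength(xs.tail)
  //    returns, so the call is not in tail position.
  def listLength(xs: List[Int]): Int =
    if (xs.isEmpty) 0 else listLength(xs.tail) + 1
}
```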
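In the following illustration, we perform an arithmetic operation on the result of the recursive call, so the Scala compiler does not recognize tail recursion and the code is not optimized. (The call also goes through a function value, function_literal, which on its own would already prevent the optimization.)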
val args = List("first.java", "second.scal", "third.python")
val function_literal = function1 _
def function1(position: Int): Int = {
  if (position >= args.length) throw new Exception("-1")
  else if (args(position).endsWith(".scala")) position
  else { function_literal(position + 1) + 1 } // +1 is applied to the result of the call
}
Output:
function_literal(0)
java.lang.Exception: -1
at function1(<console>:2)
at function1(<console>:4)
at function1(<console>:4)
at function1(<console>:4)
... 32 elided
Even though the recursive call appears at the end (the tail) of the expression, it is not the last operation: 1 still has to be added to its result after it returns. The call is therefore not in tail position, and Scala does not optimize it. Each call has to stay on the stack until the call it made returns and the addition can be performed, so every call gets its own stack frame, which is why the trace shows four function1 frames.
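A common way to remove such pending work is to carry it in an extra accumulator parameter. The sketch below is a hypothetical rewrite (AccumulatorFix, nonTail, withAcc and the parameter added are illustrative names). It calls the method directly instead of through function_literal, and it uses a list that does contain "second.scala" so that both versions return normally. nonTail keeps the pending "+ 1", while withAcc counts the additions in added and applies them once at the end, which puts the recursive call back in tail position.

```scala
import scala.annotation.tailrec

object AccumulatorFix {
  val args = List("first.java", "second.scala", "third.python")

  // Same shape as the non-optimized example (minus the function value):
  // "+ 1" is pending work, so every level keeps its own stack frame.
  def nonTail(position: Int): Int =
    if (position >= args.length) throw new Exception("-1")
    else if (args(position).endsWith(".scala")) position
    else nonTail(position + 1) + 1

  // Tail-recursive rewrite: the pending "+ 1"s are counted in `added`
  // and applied once at the base case.
  @tailrec
  def withAcc(position: Int, added: Int = 0): Int =
    if (position >= args.length) throw new Exception("-1")
    else if (args(position).endsWith(".scala")) position + added
    else withAcc(position + 1, added + 1)

  def main(argv: Array[String]): Unit =
    println(s"${nonTail(0)} ${withAcc(0)}") // prints "2 2"
}
```

Both functions compute the same value; only withAcc runs in a single stack frame.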