【Scala】再帰処理を末尾最適化するための考え方をSICPに学ぶ

Scala関数型デザイン&プログラミングScala関数型プログラミングに慣れようとしている。最初の練習問題はn番目のフィボナッチ数を出力する関数を作るというもの。
はじめは以下のように書いた。

object Main {
  //@annotation.tailrec
  def fib(n: Int): Int = {
    n match {
      case m if m <= 0 => throw new IndexOutOfBoundsException("fibonacci index should be positive number")
      case 1 => 0
      case 2 => 1
      case _ => fib(n - 2) + fib(n - 1)
    }
  }

  def main(args: Array[String]): Unit = {
    println(fib(12))
  }
}

正常に動くことは動くが、 @annotation.tailrec をつけるとコンパイルエラーになる。 この関数が末尾再帰になっていないためだ。

末尾再帰とは

末尾再帰とは、Scala関数型プログラミングに限った概念ではない。Scala関数型デザイン&プログラミングの説明だけではいまいち理解できなかったが、昔読んだSICPの第一章にも出ているのを知り、久しぶりに読み返してみた。

mitpress.mit.edu

SICPでは、ある数の階乗 n! を求める再帰関数を以下のように書いている。

(define (factorial n)
  (if (= n 1)
      1
      (* n (factorial (- n 1)))))

schemeの読み方はここでは詳しく書かないが、 (define (factorial n) (<body>)) という形式で、引数 n を取る factorial という名前の関数を定義している。また、* は2つの引数の積を返す関数で、 (* a b) と書くと ab の積が返ってくる。

例えば factorial(6) のように実行すると、この関数の中では

  • (* (6 factorial(5)))
  • (* (6 *(5 factorial(4))))
  • (* (6 *(5 * (4 factorial(3))))) ...

のように計算が進んでいく。この関数は文法的には問題ないが、最終行で (* n (factorial (- n 1))))) という計算処理を行いながら自身を呼び出している。これにより、* の第一引数である n の値を、再帰でループする回数分だけスタックに保存しておかなければいけない。もしループする回数が巨大になると、スタックに入り切らずStackOverFlowが起こりうる。

巨大な回数分ループできるようにするために、SICPでは以下のように関数を書き換えている。

(define (factorial n)
  (fact-iter 1 1 n))

(define (fact-iter product counter max-count)
  (if (> counter max-count)
      product
      (fact-iter (* counter product)
                 (+ counter 1)
                 max-count)))

factorial 関数の中に fact-iter という別の関数を定義し、その fact-iter再帰関数になっている。fact-iterproduct, counter, max-count の3つの引数を取っている。
注目すべきは、 fact-iter が自身を呼び出す際、それ以外のデータを一時的に保存していないことだ。 countermax-count より大きくなるまでは、 fact-iter は引数を更新して自身を呼び直すことしかしていない。このため、 factorial(6) の計算は

のように進んでいく。fact-oterの呼び出し以外に一時的に保存されているデータはない。

重要なのは、factorial 関数の中にもうひとつ fact-iter という関数を定義していること。しかも fact-iter 関数はコードの書き方の上では再帰と言えるが、そのプロセス自体はforループやwhileループと同じ反復プロセスになっているということにある。最初の factorial 関数の実装では、計算はせずに再帰呼び出しを最後まで終え、その後計算が遅延して行われる。一方で2番目の例では、1回のループの中で1度ずつ計算が行われ、 product 引数の値が更新されていくのがわかる。

この2番目の例のような再帰を末尾再帰と言い、処理の最後のステップで、自身を呼び出す以外の計算処理を一切していないような再帰関数を指す。これにより、反復プロセスと同様に巨大な回数のループでもスタックを食い潰さずに処理できる。

fib関数を末尾最適化する

Scalaにおける末尾最適とは、再帰関数を末尾再帰であるように書くことでコンパイラがそれを検知し、whileループと同様になるようにコンパイルすることを言う。最初のfib関数の例では、自身を呼び出す処理が fib(n - 2) + fib(n - 1) となっている。これは末尾再帰とはいえない。
以下のように書き直すことで、末尾最適化することができる。

object Main {
  def fib(n: Int): Int = {
    @annotation.tailrec
    def sumPrevElems(i: Int, j: Int, count: Int): Int = {
      if (count == 1) i
      else sumPrevElems(j, i+j, count-1)
    }
    n match {
      case m if m <= 0 => throw new IndexOutOfBoundsException("fibonacci index should be positive number")
      case _ => sumPrevElems(0, 1, n)
    }
  }

  def main(args: Array[String]): Unit = {
    println(fib(1))
  }
}

fib 関数の中に sumPrevElems 関数を定義し、その中で連続する2つのフィボナッチ数を合計している。合計した値は sumPrevElems 関数を再帰呼び出しする際に引数に含めており、指定した回数ループしたら再帰をやめるようにしてある。
末尾再帰を実現させるには、最後にまとめて計算するのではなく、 再帰呼び出しが一度行われるたびに計算を行ってその値を渡していくことがテクニックの1つかもしれない。