재귀함수의 문제점
재귀함수는 이해하기 은근히 어렵다는 점 때문에 반복문에 비해 선호되지 않는 편이지만, 재귀적인 자료구조(트리 등)을 다루거나 함수형 프로그래밍을 할 때에는 필수적입니다. 하지만 현실적으로는 재귀함수 호출이 누적되다보면 스택 오버플로가 발생한다는 큰 문제점이 있습니다. 이걸 해결하기 위해 꼬리 재귀 최적화(tail-call optimization)와 같은 기법이 있지만, 컴파일러에 따라 지원하지 않기도 하고 모든 재귀함수에 적용할 수 있는 최적화도 아닙니다.
# 1 + 4 + 9 + ... + n ** 2
def sum_of_squares(n: int) -> int:
if n == 1:
return 1
return sum_of_squares(n - 1) + n ** 2
print(sum_of_squares(10000)) # 파이썬은 기본적으로 호출 깊이를 1000으로 제한하기 때문에 에러가 발생
# 게다가 파이썬은 꼬리 재귀 최적화를 지원하지 않음
그렇다면 어떻게 해야 재귀 호출 없이 재귀함수를 구현할 수 있을까요? 함수형 프로그래밍에서 트램폴린(trampoline)이라고 불리는 기법을 사용하면 재귀함수에서 재귀 호출을 효과적으로 없앨 수 있습니다.
재귀함수 안에서 재귀 호출 없애기
재귀함수의 가장 중요한 성질은 재귀함수 안에서 재귀함수이게 꼭 자기 자신일 필요는 없습니다. 두 함수가 서로를 호출하는 재귀도 가능하기 때문이죠.를 다시 호출한다는 점이라고 할 수 있습니다. 그렇다면 재귀함수의 몸체를 다음과 같이 나눌 수 있습니다.
- 기저 조건(base case)을 처리하거나 재귀 호출에 필요한 계산을 하는 부분
- 재귀 호출을 하는 부분
- 재귀 호출의 결과를 이용해 남은 계산을 하는 부분 — 흔히 후속문(continuation)이라고 부르는 것
위 예제에서는 다음과 같이 나눌 수 있습니다.
def sum_of_squares(n: int) -> int:
# 1
if n == 1:
return 1
arg = n - 1
# 2
v = sum_of_squares(arg)
# 3
return v + n ** 2
우리가 원하는 건 재귀 호출을 하지 않는 것이기 때문에, 주어진 재귀함수에서 제대로 실행할 수 있는 부분은 오로지 1번, 재귀 호출 이전 부분입니다. 그리고 남은 재귀 호출과 후속문은 함수 안에서 실행하지 말고, 이 둘을 반환한 뒤 함수 바깥에서 지연(lazy) 실행되게 만들어버립시다. 즉, 바뀐 재귀함수가 반환하는 건 다음 두 가지 중 하나입니다.
- 더이상 재귀 호출이 필요 없을 때, 원래대로 값을 반환
- 지연된 재귀 호출과 후속문을 반환
여기서 후속문은 함수 형태로 반환하면 되고, 재귀 호출은 함수와 인자의 쌍을 반환해도 되고, 이를 캡처하여 매개 변수가 없는 함수 형태로 반환해도 됩니다. 이 글에서는 두 번째 방법을 사용하겠습니다. 또한, 재귀함수 내에 재귀 호출이 여러 번 있는 경우 후속문 역시 1번과 2번 모두 반환할 수 있다는 데 주의해야 합니다.
from typing import TypeVar, Generic, Callable
from dataclasses import dataclass
T = TypeVar('T')
@dataclass
class Done(Generic[T]):
value: T
@dataclass
class Recursive(Generic[T]):
func: Callable[[], 'Trampoline[T]']
cont: Callable[[T], 'Trampoline[T]']
Trampoline = Done[T] | Recursive[T]
def sum_of_squares(n: int) -> Trampoline[int]:
if n == 1:
return Done(1)
# 지연된 재귀 호출
def func() -> Trampoline[int]:
return sum_of_squares(n - 1)
# 재귀 호출의 후속문
def cont(v: int) -> Trampoline[int]:
return Done(v + n ** 2)
return Recursive(func, cont)
두 경우를 구분하기 위해 Done
과 Recursive
라는 두 클래스를 만들었습니다. 원래 타입 T
를 반환하던 재귀함수는 Done[T]
또는 Recursive[T]
를 반환하게 됩니다.
재귀함수 밖에서 재귀 호출 재개하기
트램폴린의 모나드적 구조
이제 바뀐 재귀함수를 호출하여 원래 반환값을 얻어내야 합니다. Done
이 반환되면 그냥 그 값인 거고, Recursive
가 반환되면 func
로 재귀 호출을 하고, 그 결과를 후속문 cont
에 넣어 최종 결과를 얻으면 됩니다. func
의 반환 타입은 T
가 아니라 Trampoline[T]
이고, cont
는 T
를 받으니, 결국 일종의 모나드가 됩니다.
T
를 받는 함수 cont
에 Trampoline[T]
를 넣을 수 있게 bind
함수를 만들어봅시다.
def bind(x: Trampoline[T], f: Callable[[T], Trampoline[T]]) -> Trampoline[T]:
match x:
case Done(value):
return f(value)
case Recursive(func, cont):
# func의 결과를 cont에 넣고, 다시 그 결과를 f에 넣어야 합니다.
return bind(bind(func(), cont), f)
하지만 이렇게 구현한 bind
는 꼬리 재귀가 아닌 재귀함수여서 처음 해결하려고 했던 문제를 여기서 다시 일으키는 꼴이 되어버립니다. 그렇다면 bind
역시 트램폴린으로 구현해버리면 되지 않을까요? bind
에 해당하는 클래스를 만들어 이 상황을 해결해봅시다.
@dataclass
class Bind(Generic[T]):
x: 'Trampoline[T]'
f: Callable[[T], 'Trampoline[T]']
Trampoline = Done[T] | Recursive[T] | Bind[T]
def bind(x: Trampoline[T], f: Callable[[T], Trampoline[T]]) -> Trampoline[T]:
match x:
case Done(value):
return f(value)
case Recursive(func, cont):
# func -> cont -> f 순으로 실행되어야 합니다.
return bind(Bind(func(), cont), f)
case Bind(x1, f1):
match x1:
case Done(value):
# f1 -> f 순으로 실행되어야 합니다.
return bind(f1(value), f)
case Recursive(func, cont):
# func -> cont -> f1 -> f 순으로 실행되어야 합니다.
return bind(Bind(Bind(func(), cont), f1), f)
case Bind(x2, f2):
# f2 -> f1 -> f 순으로 실행되어야 합니다.
# x2를 f2에 넣고, 그 결과를 f1에 넣는 것은 x2를 f1과 f2의
# 합성함수에 넣는 것과 같으므로 다음이 성립합니다.
# Bind(Bind(x2, f2), f1) == Bind(x2, lambda v: Bind(f2(v), f1))
# 중첩된 Bind를 하나로 펼친다고 생각하면 되겠습니다.
return bind(Bind(x2, lambda v: Bind(f2(v), f1)), f)
지금 이대로도 괜찮습니다만 새 구현을 자세히 보면 Recursive
와 Bind
는 func
를 담는지 아니면 func
를 실행한 결과를 담는지의 차이밖에 없습니다. Recursive
를 처리하는 방법이 단순히 내부의 func
를 실행한 값을 담은 Bind
로 만드는 것뿐이기 때문이죠. 둘을 Bind
로 통일하고, 대신 func
만을 담는 Suspend
라는 클래스를 추가합시다. 이제 Recursive
는 Suspend
를 담는 Bind
가 됩니다.
@dataclass
class Suspend(Generic[T]):
func: Callable[[], 'Trampoline[T]']
Trampoline = Done[T] | Suspend[T] | Bind[T]
def sum_of_squares(n: int) -> Trampoline[int]:
if n == 1:
return Done(1)
def func() -> Trampoline[int]:
return sum_of_squares(n - 1)
def cont(v: int) -> Trampoline[int]:
return Done(v + n ** 2)
return Bind(Suspend(func), cont)
def bind(x: Trampoline[T], f: Callable[[T], Trampoline[T]]) -> Trampoline[T]:
match x:
case Done(value):
return f(value)
case Suspend(func):
# func -> f 순으로 실행되어야 합니다.
return bind(func(), f)
case Bind(x1, f1):
match x1:
case Done(value):
# f1 -> f 순으로 실행되어야 합니다.
return bind(f1(value), f)
case Suspend(func):
# func -> f1 -> f 순으로 실행되어야 합니다.
return bind(Bind(func(), f1), f)
case Bind(x2, f2):
# f2 -> f1 -> f 순으로 실행되어야 합니다.
return bind(Bind(x2, lambda v: Bind(f2(v), f1)), f)
마지막으로 꼬리 재귀를 반복문으로 바꿔줍니다. 물론 꼬리 재귀 최적화가 지원되는 언어나 컴파일러를 사용한다면 필요없는 단계입니다.
def bind(x: Trampoline[T], f: Callable[[T], Trampoline[T]]) -> Trampoline[T]:
while True:
match x:
case Done(value):
return f(value)
case Suspend(func):
x = func()
case Bind(x1, f1):
match x1:
case Done(value):
x = f1(value)
case Suspend(func):
x = Bind(func(), f1)
case Bind(x2, f2):
# f1과 f2를 값으로 캡처하기 위해 매개변수로 넘겨줍니다.
# 재귀함수일 땐 문제가 없었지만 반복문에서는 f1과 f2가
# 같은 스코프 내에서 계속 달라지기 때문에 필요합니다.
x = Bind(x2, lambda v, f1=f1, f2=f2: Bind(f2(v), f1))
사실, bind
에서 매개변수 f
는 제일 마지막에 value
에 대해서 딱 한 번만 호출되기 때문에 그냥 x
만 받아서 value
를 내놓는 함수를 만들 수도 있습니다.
def run(x: Trampoline[T]) -> T:
while True:
match x:
case Done(value):
# case 1
return value
case Suspend(func):
# case 2
x = func()
case Bind(x1, f1):
match x1:
case Done(value):
# case 3-1
x = f1(value)
case Suspend(func):
# case 3-2
x = Bind(func(), f1)
case Bind(x2, f2):
# case 3-3
x = Bind(x2, lambda v, f1=f1, f2=f2: Bind(f2(v), f1))
드디어 sum_of_squares
를 실행할 수 있습니다. run
에 그대로 넣어버리면 되죠.
print(run(sum_of_squares(10000))) # 333383335000
어떻게 동작하는 걸까?
run(sum_of_squares(4))
를 실행하면 어떤 일이 일어나는지 차근차근 살펴봅시다. 이때 run
의 반복문 내에서 x
의 값은 다음과 같이 변합니다. 캡처한 값들은 함수 이름 뒤에 대괄호로 표시했습니다.
Bind(Suspend(func[n=4]), cont[n=4])
# case 3-2: 지연된 재귀 호출 func[n=4] 재개
Bind(Bind(Suspend(func[n=3]), cont[n=3]), cont[n=4])
# case 3-3: 중첩된 Bind를 펼침
Bind(Suspend(func[n=3]), lambda v: Bind(cont[n=3](v), cont[n=4]))
# case 3-2: 지연된 재귀 호출 func[n=3] 재개
Bind(Bind(Suspend(func[n=2]), cont[n=2]), lambda v: Bind(cont[n=3](v), cont[n=4]))
# case 3-3: 중첩된 Bind를 펼침
Bind(Suspend(func[n=2]), lambda v: Bind(cont[n=2](v), lambda v: Bind(cont[n=3](v), cont[n=4])))
# case 3-2: 지연된 재귀 호출 func[n=2] 재개
Bind(Done(1), lambda v: Bind(cont[n=2](v), lambda v: Bind(cont[n=3](v), cont[n=4])))
# case 3-1: 후속문 실행
Bind(Done(5), lambda v: Bind(cont[n=3](v), cont[n=4]))
# case 3-1: 후속문 실행
Bind(Done(14), cont[n=4])
# case 3-1: 후속문 실행
Done(30)
# case 1: Done이므로 반환
즉, run
은 반복문 내에서 지연된 재귀 호출을 재개하여 기저 조건까지 도달한 뒤, 후속문을 역순으로 실행하며 기저 조건의 결과로부터 최종 결과를 얻어냅니다. 이 과정에서 중첩된 Bind
는 펼쳐져 뒤쪽 람다 함수에 쌓이게 되고, 이것이 곧 호출 스택의 역할을 합니다.
더 많은 예제
피보나치 수열
def fibonacci(n: int) -> int:
if n == 0:
return 0
if n == 1:
return 1
return fibonacci(n - 1) + fibonacci(n - 2)
아무런 최적화가 없는 재귀함수 형태의 피보나치 수열입니다. 여기에 그대로 트램폴린을 적용해봅시다. 재귀 호출이 두 번인 데 주의합시다.
def fibonacci(n: int) -> Trampoline[int]:
if n == 0:
return Done(0)
if n == 1:
return Done(1)
def func1() -> Trampoline[int]:
return fibonacci(n - 1)
def cont1(v1: int) -> Trampoline[int]:
def func2() -> Trampoline[int]:
return fibonacci(n - 2)
def cont2(v2: int) -> Trampoline[int]:
return Done(v1 + v2)
return Bind(Suspend(func2), cont2)
return Bind(Suspend(func1), cont1)
run(fibonacci(3))
을 호출하면 다음과 같은 과정을 거칩니다.
Bind(Suspend(func1[n=3]), cont1[n=3])
# case 3-2
Bind(Bind(Suspend(func1[n=2]), cont1[n=2]), cont1[n=3])
# case 3-3
Bind(Suspend(func1[n=2]), lambda v: Bind(cont1[n=2](v), cont1[n=3]))
# case 3-2
Bind(Done(1), lambda v: Bind(cont1[n=2](v), cont1[n=3]))
# case 3-1
Bind(Bind(Suspend(func2[n=2]), cont2[n=2,v1=1]), cont1[n=3])
# case 3-3
Bind(Suspend(func2[n=2]), lambda v: Bind(cont2[n=2,v1=1](v), cont1[n=3]))
# case 3-2
Bind(Done(0), lambda v: Bind(cont2[n=2,v1=1](v), cont1[n=3]))
# case 3-1
Bind(Done(1), cont1[n=3])
# case 3-1
Bind(Suspend(func2[n=3]), cont2[n=3,v1=1])
# case 3-2
Bind(Done(1), cont2[n=3,v1=1])
# case 3-1
Done(2)
# case 1
상호 재귀
상호 재귀(mutual recursion)은 둘 이상의 함수가 서로를 호출하는 재귀 형태입니다. 호프스테터 암수 수열(Hofstadter Female and Male sequences)은 전형적인 상호 재귀의 예시입니다.
트램폴린을 이용해 재귀함수로 구현하면 다음과 같습니다.
def f(n: int) -> Trampoline[int]:
if n == 0:
return Done(1)
def func1() -> Trampoline[int]:
return f(n - 1)
def cont1(v1: int) -> Trampoline[int]:
def func2() -> Trampoline[int]:
return m(v1)
def cont2(v2: int) -> Trampoline[int]:
return Done(n - v2)
return Bind(Suspend(func2), cont2)
return Bind(Suspend(func1), cont1)
def m(n: int) -> Trampoline[int]:
if n == 0:
return Done(0)
def func1() -> Trampoline[int]:
return m(n - 1)
def cont1(v1: int) -> Trampoline[int]:
def func2() -> Trampoline[int]:
return f(v1)
def cont2(v2: int) -> Trampoline[int]:
return Done(n - v2)
return Bind(Suspend(func2), cont2)
return Bind(Suspend(func1), cont1)