트램폴린

2024년 8월 19일

재귀함수의 문제점

재귀함수는 이해하기 은근히 어렵다는 점 때문에 반복문에 비해 선호되지 않는 편이지만, 재귀적인 자료구조(트리 등)을 다루거나 함수형 프로그래밍을 할 때에는 필수적입니다. 하지만 현실적으로는 재귀함수 호출이 누적되다보면 스택 오버플로가 발생한다는 큰 문제점이 있습니다. 이걸 해결하기 위해 꼬리 재귀 최적화(tail-call optimization)와 같은 기법이 있지만, 컴파일러에 따라 지원하지 않기도 하고 모든 재귀함수에 적용할 수 있는 최적화도 아닙니다.


# 1 + 4 + 9 + ... + n ** 2
def sum_of_squares(n: int) -> int:
    if n == 1:
        return 1
    return sum_of_squares(n - 1) + n ** 2

print(sum_of_squares(10000))  # 파이썬은 기본적으로 호출 깊이를 1000으로 제한하기 때문에 에러가 발생
                              # 게다가 파이썬은 꼬리 재귀 최적화를 지원하지 않음

그렇다면 어떻게 해야 재귀 호출 없이 재귀함수를 구현할 수 있을까요? 함수형 프로그래밍에서 트램폴린(trampoline)이라고 불리는 기법을 사용하면 재귀함수에서 재귀 호출을 효과적으로 없앨 수 있습니다.

재귀함수 안에서 재귀 호출 없애기

재귀함수의 가장 중요한 성질은 재귀함수 안에서 재귀함수이게 꼭 자기 자신일 필요는 없습니다. 두 함수가 서로를 호출하는 재귀도 가능하기 때문이죠.를 다시 호출한다는 점이라고 할 수 있습니다. 그렇다면 재귀함수의 몸체를 다음과 같이 나눌 수 있습니다.

  1. 기저 조건(base case)을 처리하거나 재귀 호출에 필요한 계산을 하는 부분
  2. 재귀 호출을 하는 부분
  3. 재귀 호출의 결과를 이용해 남은 계산을 하는 부분 — 흔히 후속문(continuation)이라고 부르는 것

위 예제에서는 다음과 같이 나눌 수 있습니다.


def sum_of_squares(n: int) -> int:
    # 1
    if n == 1:
        return 1
    arg = n - 1

    # 2
    v = sum_of_squares(arg)

    # 3
    return v + n ** 2

우리가 원하는 건 재귀 호출을 하지 않는 것이기 때문에, 주어진 재귀함수에서 제대로 실행할 수 있는 부분은 오로지 1번, 재귀 호출 이전 부분입니다. 그리고 남은 재귀 호출과 후속문은 함수 안에서 실행하지 말고, 이 둘을 반환한 뒤 함수 바깥에서 지연(lazy) 실행되게 만들어버립시다. 즉, 바뀐 재귀함수가 반환하는 건 다음 두 가지 중 하나입니다.

  1. 더이상 재귀 호출이 필요 없을 때, 원래대로 값을 반환
  2. 지연된 재귀 호출과 후속문을 반환

여기서 후속문은 함수 형태로 반환하면 되고, 재귀 호출은 함수와 인자의 쌍을 반환해도 되고, 이를 캡처하여 매개 변수가 없는 함수 형태로 반환해도 됩니다. 이 글에서는 두 번째 방법을 사용하겠습니다. 또한, 재귀함수 내에 재귀 호출이 여러 번 있는 경우 후속문 역시 1번과 2번 모두 반환할 수 있다는 데 주의해야 합니다.


from typing import TypeVar, Generic, Callable
from dataclasses import dataclass

T = TypeVar('T')

@dataclass
class Done(Generic[T]):
    value: T

@dataclass
class Recursive(Generic[T]):
    func: Callable[[], 'Trampoline[T]']
    cont: Callable[[T], 'Trampoline[T]']

Trampoline = Done[T] | Recursive[T]

def sum_of_squares(n: int) -> Trampoline[int]:
    if n == 1:
        return Done(1)

    # 지연된 재귀 호출
    def func() -> Trampoline[int]:
        return sum_of_squares(n - 1)
    # 재귀 호출의 후속문
    def cont(v: int) -> Trampoline[int]:
        return Done(v + n ** 2)
    return Recursive(func, cont)

두 경우를 구분하기 위해 DoneRecursive라는 두 클래스를 만들었습니다. 원래 타입 T를 반환하던 재귀함수는 Done[T] 또는 Recursive[T]를 반환하게 됩니다.

재귀함수 밖에서 재귀 호출 재개하기

트램폴린의 모나드적 구조

이제 바뀐 재귀함수를 호출하여 원래 반환값을 얻어내야 합니다. Done이 반환되면 그냥 그 값인 거고, Recursive가 반환되면 func로 재귀 호출을 하고, 그 결과를 후속문 cont에 넣어 최종 결과를 얻으면 됩니다. func의 반환 타입은 T가 아니라 Trampoline[T]이고, contT를 받으니, 결국 일종의 모나드가 됩니다.

T를 받는 함수 contTrampoline[T]를 넣을 수 있게 bind 함수를 만들어봅시다.


def bind(x: Trampoline[T], f: Callable[[T], Trampoline[T]]) -> Trampoline[T]:
    match x:
        case Done(value):
            return f(value)
        case Recursive(func, cont):
            # func의 결과를 cont에 넣고, 다시 그 결과를 f에 넣어야 합니다.
            return bind(bind(func(), cont), f)

하지만 이렇게 구현한 bind는 꼬리 재귀가 아닌 재귀함수여서 처음 해결하려고 했던 문제를 여기서 다시 일으키는 꼴이 되어버립니다. 그렇다면 bind 역시 트램폴린으로 구현해버리면 되지 않을까요? bind에 해당하는 클래스를 만들어 이 상황을 해결해봅시다.


@dataclass
class Bind(Generic[T]):
    x: 'Trampoline[T]'
    f: Callable[[T], 'Trampoline[T]']

Trampoline = Done[T] | Recursive[T] | Bind[T]

def bind(x: Trampoline[T], f: Callable[[T], Trampoline[T]]) -> Trampoline[T]:
    match x:
        case Done(value):
            return f(value)
        case Recursive(func, cont):
            # func -> cont -> f 순으로 실행되어야 합니다.
            return bind(Bind(func(), cont), f)
        case Bind(x1, f1):
            match x1:
                case Done(value):
                    # f1 -> f 순으로 실행되어야 합니다.
                    return bind(f1(value), f)
                case Recursive(func, cont):
                    # func -> cont -> f1 -> f 순으로 실행되어야 합니다.
                    return bind(Bind(Bind(func(), cont), f1), f)
                case Bind(x2, f2):
                    # f2 -> f1 -> f 순으로 실행되어야 합니다.
                    # x2를 f2에 넣고, 그 결과를 f1에 넣는 것은 x2를 f1과 f2의
                    # 합성함수에 넣는 것과 같으므로 다음이 성립합니다.
                    # Bind(Bind(x2, f2), f1) == Bind(x2, lambda v: Bind(f2(v), f1))
                    # 중첩된 Bind를 하나로 펼친다고 생각하면 되겠습니다.
                    return bind(Bind(x2, lambda v: Bind(f2(v), f1)), f)

지금 이대로도 괜찮습니다만 새 구현을 자세히 보면 RecursiveBindfunc를 담는지 아니면 func를 실행한 결과를 담는지의 차이밖에 없습니다. Recursive를 처리하는 방법이 단순히 내부의 func를 실행한 값을 담은 Bind로 만드는 것뿐이기 때문이죠. 둘을 Bind로 통일하고, 대신 func만을 담는 Suspend라는 클래스를 추가합시다. 이제 RecursiveSuspend를 담는 Bind가 됩니다.


@dataclass
class Suspend(Generic[T]):
    func: Callable[[], 'Trampoline[T]']

Trampoline = Done[T] | Suspend[T] | Bind[T]

def sum_of_squares(n: int) -> Trampoline[int]:
    if n == 1:
        return Done(1)

    def func() -> Trampoline[int]:
        return sum_of_squares(n - 1)
    def cont(v: int) -> Trampoline[int]:
        return Done(v + n ** 2)
    return Bind(Suspend(func), cont)

def bind(x: Trampoline[T], f: Callable[[T], Trampoline[T]]) -> Trampoline[T]:
    match x:
        case Done(value):
            return f(value)
        case Suspend(func):
            # func -> f 순으로 실행되어야 합니다.
            return bind(func(), f)
        case Bind(x1, f1):
            match x1:
                case Done(value):
                    # f1 -> f 순으로 실행되어야 합니다.
                    return bind(f1(value), f)
                case Suspend(func):
                    # func -> f1 -> f 순으로 실행되어야 합니다.
                    return bind(Bind(func(), f1), f)
                case Bind(x2, f2):
                    # f2 -> f1 -> f 순으로 실행되어야 합니다.
                    return bind(Bind(x2, lambda v: Bind(f2(v), f1)), f)

마지막으로 꼬리 재귀를 반복문으로 바꿔줍니다. 물론 꼬리 재귀 최적화가 지원되는 언어나 컴파일러를 사용한다면 필요없는 단계입니다.


def bind(x: Trampoline[T], f: Callable[[T], Trampoline[T]]) -> Trampoline[T]:
    while True:
        match x:
            case Done(value):
                return f(value)
            case Suspend(func):
                x = func()
            case Bind(x1, f1):
                match x1:
                    case Done(value):
                        x = f1(value)
                    case Suspend(func):
                        x = Bind(func(), f1)
                    case Bind(x2, f2):
                        # f1과 f2를 값으로 캡처하기 위해 매개변수로 넘겨줍니다.
                        # 재귀함수일 땐 문제가 없었지만 반복문에서는 f1과 f2가
                        # 같은 스코프 내에서 계속 달라지기 때문에 필요합니다.
                        x = Bind(x2, lambda v, f1=f1, f2=f2: Bind(f2(v), f1))

사실, bind에서 매개변수 f는 제일 마지막에 value에 대해서 딱 한 번만 호출되기 때문에 그냥 x만 받아서 value를 내놓는 함수를 만들 수도 있습니다.


def run(x: Trampoline[T]) -> T:
    while True:
        match x:
            case Done(value):
                # case 1
                return value
            case Suspend(func):
                # case 2
                x = func()
            case Bind(x1, f1):
                match x1:
                    case Done(value):
                        # case 3-1
                        x = f1(value)
                    case Suspend(func):
                        # case 3-2
                        x = Bind(func(), f1)
                    case Bind(x2, f2):
                        # case 3-3
                        x = Bind(x2, lambda v, f1=f1, f2=f2: Bind(f2(v), f1))

드디어 sum_of_squares를 실행할 수 있습니다. run에 그대로 넣어버리면 되죠.


print(run(sum_of_squares(10000)))  # 333383335000

어떻게 동작하는 걸까?

run(sum_of_squares(4))를 실행하면 어떤 일이 일어나는지 차근차근 살펴봅시다. 이때 run의 반복문 내에서 x의 값은 다음과 같이 변합니다. 캡처한 값들은 함수 이름 뒤에 대괄호로 표시했습니다.


Bind(Suspend(func[n=4]), cont[n=4])
# case 3-2: 지연된 재귀 호출 func[n=4] 재개
Bind(Bind(Suspend(func[n=3]), cont[n=3]), cont[n=4])
# case 3-3: 중첩된 Bind를 펼침
Bind(Suspend(func[n=3]), lambda v: Bind(cont[n=3](v), cont[n=4]))
# case 3-2: 지연된 재귀 호출 func[n=3] 재개
Bind(Bind(Suspend(func[n=2]), cont[n=2]), lambda v: Bind(cont[n=3](v), cont[n=4]))
# case 3-3: 중첩된 Bind를 펼침
Bind(Suspend(func[n=2]), lambda v: Bind(cont[n=2](v), lambda v: Bind(cont[n=3](v), cont[n=4])))
# case 3-2: 지연된 재귀 호출 func[n=2] 재개
Bind(Done(1), lambda v: Bind(cont[n=2](v), lambda v: Bind(cont[n=3](v), cont[n=4])))
# case 3-1: 후속문 실행
Bind(Done(5), lambda v: Bind(cont[n=3](v), cont[n=4]))
# case 3-1: 후속문 실행
Bind(Done(14), cont[n=4])
# case 3-1: 후속문 실행
Done(30)
# case 1: Done이므로 반환

즉, run은 반복문 내에서 지연된 재귀 호출을 재개하여 기저 조건까지 도달한 뒤, 후속문을 역순으로 실행하며 기저 조건의 결과로부터 최종 결과를 얻어냅니다. 이 과정에서 중첩된 Bind는 펼쳐져 뒤쪽 람다 함수에 쌓이게 되고, 이것이 곧 호출 스택의 역할을 합니다.

더 많은 예제

피보나치 수열


def fibonacci(n: int) -> int:
    if n == 0:
        return 0
    if n == 1:
        return 1
    return fibonacci(n - 1) + fibonacci(n - 2)

아무런 최적화가 없는 재귀함수 형태의 피보나치 수열입니다. 여기에 그대로 트램폴린을 적용해봅시다. 재귀 호출이 두 번인 데 주의합시다.


def fibonacci(n: int) -> Trampoline[int]:
    if n == 0:
        return Done(0)
    if n == 1:
        return Done(1)

    def func1() -> Trampoline[int]:
        return fibonacci(n - 1)
    def cont1(v1: int) -> Trampoline[int]:
        def func2() -> Trampoline[int]:
            return fibonacci(n - 2)
        def cont2(v2: int) -> Trampoline[int]:
            return Done(v1 + v2)
        return Bind(Suspend(func2), cont2)
    return Bind(Suspend(func1), cont1)

run(fibonacci(3))을 호출하면 다음과 같은 과정을 거칩니다.


Bind(Suspend(func1[n=3]), cont1[n=3])
# case 3-2
Bind(Bind(Suspend(func1[n=2]), cont1[n=2]), cont1[n=3])
# case 3-3
Bind(Suspend(func1[n=2]), lambda v: Bind(cont1[n=2](v), cont1[n=3]))
# case 3-2
Bind(Done(1), lambda v: Bind(cont1[n=2](v), cont1[n=3]))
# case 3-1
Bind(Bind(Suspend(func2[n=2]), cont2[n=2,v1=1]), cont1[n=3])
# case 3-3
Bind(Suspend(func2[n=2]), lambda v: Bind(cont2[n=2,v1=1](v), cont1[n=3]))
# case 3-2
Bind(Done(0), lambda v: Bind(cont2[n=2,v1=1](v), cont1[n=3]))
# case 3-1
Bind(Done(1), cont1[n=3])
# case 3-1
Bind(Suspend(func2[n=3]), cont2[n=3,v1=1])
# case 3-2
Bind(Done(1), cont2[n=3,v1=1])
# case 3-1
Done(2)
# case 1

상호 재귀

상호 재귀(mutual recursion)은 둘 이상의 함수가 서로를 호출하는 재귀 형태입니다. 호프스테터 암수 수열(Hofstadter Female and Male sequences)은 전형적인 상호 재귀의 예시입니다.

\begin{align*} F(0) &= 1 \\ M(0) &= 0 \\ F(n) &= n - M(F(n - 1)) \quad \text{if } n > 0 \\ M(n) &= n - F(M(n - 1)) \quad \text{if } n > 0 \end{align*}

트램폴린을 이용해 재귀함수로 구현하면 다음과 같습니다.


def f(n: int) -> Trampoline[int]:
    if n == 0:
        return Done(1)

    def func1() -> Trampoline[int]:
        return f(n - 1)
    def cont1(v1: int) -> Trampoline[int]:
        def func2() -> Trampoline[int]:
            return m(v1)
        def cont2(v2: int) -> Trampoline[int]:
            return Done(n - v2)
        return Bind(Suspend(func2), cont2)
    return Bind(Suspend(func1), cont1)

def m(n: int) -> Trampoline[int]:
    if n == 0:
        return Done(0)

    def func1() -> Trampoline[int]:
        return m(n - 1)
    def cont1(v1: int) -> Trampoline[int]:
        def func2() -> Trampoline[int]:
            return f(v1)
        def cont2(v2: int) -> Trampoline[int]:
            return Done(n - v2)
        return Bind(Suspend(func2), cont2)
    return Bind(Suspend(func1), cont1)

참고 문헌