이진 탐색 트리 (Binary Search Tree) – C++ 구현

파이썬으로 구현해볼까 하다가 연습겸 C++로 구현해봤습니다. 따라서 아래 코드는 별로 좋은 C++ 코드가 아닐 가능성이 큽니다.

 

클래스 정의

먼저 템플릿을 써서 이진 탐색 트리의 몸체를 정의해줍니다.

template<typename KeyType>
class BinarySearchTree {
private:
    class Node;

    static Node *const NIL;
    Node *root, *node_max, *node_min;
    size_t size;

public:
    class iterator;

    BinarySearchTree() : root(NIL), node_max(NIL), node_min(NIL), size(0) {}
    ~BinarySearchTree() {
        if (root != NIL)
            delete root;
    }

    iterator insert(KeyType key);
    iterator search(KeyType key);
    void remove(KeyType key);
    void remove(iterator it);

    inline iterator begin(void);
    inline iterator end(void);
    inline iterator rbegin(void);
    inline iterator rend(void);

    void print(void);
};

BinarySearchTree 클래스는 private으로 Node 클래스와 public으로 iterator 클래스를 가지고 있습니다. 또 static 멤버 변수로 앞 포스트에서도 설명했던 NIL 노드를 가지고 있고, 그 외에 private으로 root, node_max, node_min를 가집니다. 각각 루트 노드, 최댓값을 가지는 노드, 최솟값을 가지는 노드를 가리킵니다. 사실 루트 노드만 저장해줘도 트리 구현엔 문제가 없지만 반복자(iterator)를 구현할 때 있으면 좋습니다. size 변수는 이름에서 유추할 수 있듯이 트리 크기를 저장합니다. 메서드로는 삽입, 검색, 삭제 연산을 해줄 insert, search, remove가 있고 STL에서 흔히 볼 수 있는 메서드 begin, end, rbegin, rend가 있습니다. 마지막 print는 디버깅용으로, 트리를 출력해줍니다.

일단 NIL 노드를 할당합시다. 덧붙이자면 앞 포스트의 그림에선 NIL 노드를 자식이 없는 노드마다 일일이 따로 그려줬지만, 실제로 NIL 노드를 구현할 경우에는 보통 유일한 NIL 노드가 하나 있고, 자식이 없으면 그 NIL 노드를 가리키게 하는 형태로 구현합니다.

template<typename KeyType>
typename BinarySearchTree<KeyType>::Node *const BinarySearchTree<KeyType>::NIL = new BinarySearchTree<KeyType>::Node();

그 다음 중첩 클래스를 정의합니다.

template<typename KeyType>
class BinarySearchTree<KeyType>::Node {
public:
    KeyType key;
    Node *left, *right, *parent;

    Node() : left(nullptr), right(nullptr), parent(nullptr) {}
    Node(KeyType k) : key(k), left(NIL), right(NIL), parent(NIL) {}
    ~Node() {
        if (this->left != NIL)
            delete this->left;
        if (this->right != NIL)
            delete this->right;
    }

    Node *predecessor(void);
    Node *successor(void);
};

Node 클래스는 키와, 왼쪽 자식, 오른쪽 자식, 부모를 가리키는 포인터를 가집니다. 생성자에서는 세 포인터가 NIL 노드를 가리키게 하고, 소멸자에서는 왼쪽 자식과 오른쪽 자식을 해제해주는 과정이 필수적으로 필요합니다. 마지막에 보이는 두 메서드 predecessorsuccessor는 각각 키를 크기순으로 정렬했을 때 바로 이전 노드(predecessor, 전임자)와 바로 다음 노드(successor, 후임자)를 계산해줍니다. 역시 반복자를 구현할 때 필요합니다.

template<typename KeyType>
typename BinarySearchTree<KeyType>::Node *BinarySearchTree<KeyType>::Node::predecessor(void) {
    Node *cur = this;
    if (cur == NIL)
        return NIL;
    if (cur->left != NIL) {
        cur = cur->left;
        while (cur->right != NIL)
            cur = cur->right;
        return cur;
    }
    while (cur->parent != NIL && cur->parent->left == cur)
        cur = cur->parent;
    return cur->parent;
}

template<typename KeyType>
typename BinarySearchTree<KeyType>::Node *BinarySearchTree<KeyType>::Node::successor(void) {
    Node *cur = this;
    if (cur == NIL)
        return NIL;
    if (cur->right != NIL) {
        cur = cur->right;
        while (cur->left != NIL)
            cur = cur->left;
        return cur;
    }
    while (cur->parent != NIL && cur->parent->right == cur)
        cur = cur->parent;
    return cur->parent;
}

predecessorsuccessor 메서드입니다. 어떤 노드가 오른쪽 자식을 가질 때 successor를 구하는 법은 앞 포스트에서 간략히 설명했죠. 오른쪽 서브트리에서 제일 왼쪽 노드를 찾으면 됩니다. 만약 오른쪽 자식이 없다면 successor는 이 노드가 왼쪽 서브트리의 가장 오른쪽 노드인 노드입니다. 즉, 현재 노드에서 시작해서 이 노드가 부모의 오른쪽 자식이면 부모로 올라가고, 아니면 그때 부모가 successor죠. predecessor는 successor와 반대로 하면 됩니다.

template<typename KeyType>
class BinarySearchTree<KeyType>::iterator {
private:
    Node *current;

public:
    iterator() {}
    iterator(Node *node) : current(node) {}

    inline iterator &operator++(void) {
        this->current = this->current->successor();
        return *this;
    }
    inline iterator operator++(int) {
        iterator tmp(this->current);
        this->current = this->current->successor();
        return tmp;
    }
    inline iterator &operator--(void) {
        this->current = this->current->predecessor();
        return *this;
    }
    inline iterator operator--(int) {
        iterator tmp(this->current);
        this->current = this->current->predecessor();
        return tmp;
    }
    inline KeyType operator*(void) {
        return this->current->key;
    }
    inline bool operator==(const iterator other) {
        return this->current == other.current;
    }
    inline bool operator!=(const iterator other) {
        return this->current != other.current;
    }

    friend void BinarySearchTree<KeyType>::remove(iterator it);
};

iterator 클래스는 현재 방문 중인 노드를 가리키는 변수 current를 가집니다. STL에서는 크기의 역순으로 방문하기 위해 reverse_iterator라는 걸 따로 정의했는데, 귀찮아서 그냥 감소 연산자를 써서 역순 방문이 가능하도록 만들었습니다. 증가 연산자는 현재 노드의 successor, 감소 연산자는 predecessor를 방문하게 됩니다. 참조 연산자는 현재 노드의 키를 리턴합니다. 그 외에 비교 연산자를 구현했습니다.

template<typename KeyType>
typename BinarySearchTree<KeyType>::iterator BinarySearchTree<KeyType>::insert(KeyType key) {
    this->size++;
    if (this->root == NIL) {
        // make new node for root
        this->root = new Node(key);
        this->node_max = this->node_min = this->root;
        return iterator(this->root);
    }
    Node *cur = this->root;
    while (true) {
        if (key < cur->key) {
            // left subtree
            if (cur->left == NIL) {
                cur->left = new Node(key);
                cur->left->parent = cur;
                if (cur == this->node_min)
                    this->node_min = cur->left;
                cur = cur->left;
                break;
            }
            cur = cur->left;
        }
        else if (cur->key < key) {
            // right subtree
            if (cur->right == NIL) {
                cur->right = new Node(key);
                cur->right->parent = cur;
                if (cur == this->node_max)
                    this->node_max = cur->right;
                cur = cur->right;
                break;
            }
            cur = cur->right;
        }
        else
            // key already exist
            return iterator(NIL);
    }
    return iterator(cur);
}

insert는 추가할 키를 받아서 새로 삽입한 노드를 가리키는 반복자를 리턴하도록 했습니다. (이미 있는 키인 경우 NIL을 가리키는 반복자를 리턴) 트리가 비어있는 경우는 예외로 처리합니다.

template<typename KeyType>
typename BinarySearchTree<KeyType>::iterator BinarySearchTree<KeyType>::search(KeyType key) {
    Node *cur = this->root;
    while (cur != NIL && cur->key != key) {
        if (key < cur->key)
            // left subtree
            cur = cur->left;
        else
            // right subtree
            cur = cur->right;
    }
    return iterator(cur);
}

search입니다. 역시 키를 받아서 그 키를 가진 노드를 가리키는 반복자를 리턴합니다. 훨씬 간단하죠.

template<typename KeyType>
void BinarySearchTree<KeyType>::remove(BinarySearchTree<KeyType>::iterator it) {
    Node *v = it.current;
    if (v == NIL)
        return;
    this->size--;
    if (v->left != NIL && v->right != NIL) {
        // two children
        // replace v by its successor
        Node *s = v->successor();
        v->key = s->key;
        v = s;
    }
    if (v == this->node_max)
        this->node_max = v->predecessor();
    if (v == this->node_min)
        this->node_min = v->successor();
    if (v->left == NIL && v->right == NIL) {
        // no children
        if (v == this->root)
            // if v is root, make root be NIL
            this->root = NIL;
        else if (v == v->parent->left)
            // v is left child of its parent
            v->parent->left = NIL;
        else
            // v is right child of its parent
            v->parent->right = NIL;
    }
    else {
        // one child
        Node *child = (v->left == NIL? v->right : v->left);
        if (v == this->root) {
            // if v is root, make child be root
            child->parent = NIL;
            this->root = child;
        }
        else if (v == v->parent->left) {
            v->parent->left = child;
            child->parent = v->parent;
        }
        else {
            v->parent->right = child;
            child->parent = v->parent;
        }
    }
    delete v;
}

 

remove는 매우 복잡합니다. 먼저 반복자를 받아서 반복자가 가리키는 노드를 삭제하도록 해봅시다. 첫 if 문은 삭제할 노드의 자식이 둘일 경우 successor의 키를 삭제할 노드에 복사하고 삭제할 노드를 successor로 변경해 줍니다. 이 처리를 하면 삭제할 노드는 무조건 자식이 하나 아니면 없으므로, 각 경우에 대해 적절히 노드를 삭제하면 됩니다.

template<typename KeyType>
void BinarySearchTree<KeyType>::remove(KeyType key) {
    this->remove(this->search(key));
}

키를 받아서 삭제를 하고 싶으면 search로 노드를 찾으면 됩니다.

template<typename KeyType>
typename BinarySearchTree<KeyType>::iterator BinarySearchTree<KeyType>::begin(void) {
    return iterator(this->node_min);
}

template<typename KeyType>
typename BinarySearchTree<KeyType>::iterator BinarySearchTree<KeyType>::end(void) {
    return iterator(NIL);
}

template<typename KeyType>
typename BinarySearchTree<KeyType>::iterator BinarySearchTree<KeyType>::rbegin(void) {
    return iterator(this->node_max);
}

template<typename KeyType>
typename BinarySearchTree<KeyType>::iterator BinarySearchTree<KeyType>::rend(void) {
    return iterator(NIL);
}

위 네 메서드는 반복자를 써서 트리를 순회할 때 필요한 시작점과 끝점을 계산해줍니다.

template<typename KeyType>
void BinarySearchTree<KeyType>::print(void) {
    cout << "size: " << this->size << endl;
    vector<Node *> cur_level = {this->root};
    vector<vector<Node *>> levels;
    bool no_leaf = false;
    while (!no_leaf) {
        levels.push_back(cur_level);
        cur_level.clear();
        no_leaf = true;
        for (Node *p : *levels.rbegin()) {
            if (p != NIL && p->left != NIL) {
                cur_level.push_back(p->left);
                no_leaf = false;
            }
            else
                cur_level.push_back(NIL);
            if (p != NIL && p->right != NIL) {
                cur_level.push_back(p->right);
                no_leaf = false;
            }
            else
                cur_level.push_back(NIL);
        }
    }
    int width = 4 * levels.rbegin()->size();
    for (vector<Node *> &lv : levels) {
        for (Node *node : lv) {
            string s;
            if (node != NIL)
                s = to_string(node->key);
            cout << string((width - s.size()) / 2, ' ') << s << string((width - s.size() + 1) / 2, ' ');
        }
        cout << endl;
        width /= 2;
    }
}

마지막으로 디버깅용 메서드 print입니다. BFS를 이용해 각 높이별로 노드 목록을 만들고, 폭에 맞춰서 적절히 출력합니다.

이 외에도 트리 크기, 트리가 비었는지 여부를 계산할 메서드가 필요하지만, 생략하겠습니다.

BinarySearchTree<int> b;
b.insert(4);
b.insert(2);
b.insert(6);
b.insert(1);
b.insert(5);
b.insert(7);
b.insert(3);

BinarySearchTree<int>::iterator it;
for (it = b.begin(); it != b.end(); it++)
    cout << *it << ' ';
cout << endl;
for (it = b.rbegin(); it != b.rend(); it--)
    cout << *it << ' ';
cout << endl;

b.remove(5);
it = b.search(2);
b.remove(it);

for (it = b.begin(); it != b.end(); it++)
    cout << *it << ' ';
cout << endl;
for (it = b.rbegin(); it != b.rend(); it--)
    cout << *it << ' ';
cout << endl;

BinarySearchTree클래스는 위와 같이 사용할 수 있습니다.