Python - 二分探索木を書いてみた

Python二分探索木を実装してみました。 追加、検索、削除ひと通りあります。

実装コード

# -*- coding: utf-8 -*-
class Node:
    def __init__(self, label, left=None, right=None):
        self.label = label
        self.left = left
        self.right = right

    def add(self, label):
        return self.__add(Node(label))

    def __add(self, append):
        parent = self
        current = self
        is_left = False
        while current is not None:
            if current.label == append.label:
                # 重複データの場合は何もしない
                return self
            elif append.label < current.label:
                parent = current
                current = current.left
                is_left = True
            else:
                parent = current
                current = current.right
                is_left = False

        self.__reconnect(parent, append, is_left)
        return self

    def get(self, key):
        return self.__get(Node(key))

    def __get(self, search):
        current = self
        while current is not None:
            if current.label == search.label:
                return current
            elif search.label < current.label:
                current = current.left
            else:
                current = current.right

        return Node(None)

    def delete(self, key):
        return self.__delete(Node(key))

    def __delete(self, search):
        if self.label == search.label and self.is_leaf():
            self.label = None
            return True

        parent = self
        current = self
        is_left = False
        while current is not None:
            if current.label == search.label:
                self.__reconnect(parent, self.__fetch_delete_child(current), is_left)
                return True
            elif search.label < current.label:
                parent = current
                current = current.left
                is_left = True
            else:
                parent = current
                current = current.right
                is_left = False

        return False

    def __fetch_delete_child(self, target):
        if target.is_leaf():
            return None
        elif target.left is not None and target.right is not None:
            left_bottom, left_bottom_parent = self.__min_and_parent(target.right)
            # left_bottomは左部分木の末端なので、リーフじゃなければ必ず右部分木しかない
            # TODO: fetchと言いながらここで参照を書き換えている(副作用)
            left_bottom_parent.left = None if left_bottom.is_leaf() else left_bottom.right
            left_bottom.left = target.left
            left_bottom.right = target.right
            return left_bottom
        else:
            # 左部分木のみ or 右部分木のみ
            return target.left if target.left is not None else target.right

    def __min_and_parent(self, node):
        """ 左部分木の最下部(リーフ)とその親を取得 """
        left_bottom_parent = node
        left_bottom = node
        while left_bottom.left is not None:
            left_bottom_parent = left_bottom
            left_bottom = left_bottom.left
        return (left_bottom, left_bottom_parent)

    def __reconnect(self, parent, child, is_left):
        if is_left:
            parent.left = child
        else:
            parent.right = child

    def is_leaf(self):
        return self.left is None and self.right is None

    def __str__(self):
        if self.is_leaf():
            return "(Leaf %s)" % (self.label)
        return "(Node %s %s %s)" % (self.label, str(self.left), str(self.right))

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if isinstance(other, Node):
            return str(self) == str(other)
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

テストコード

import unittest

class Test(unittest.TestCase):
    def setUp(self):
        pass

    def test_eq(self):
        """ ノードの等値性 """
        self.assertEquals(Node(1), Node(1))
        self.assertNotEquals(Node(10), Node(1))

    def test_add(self):
        """ ノード追加 """
        node = Node(5).add(3).add(2).add(10).add(20)
        expected = Node(5,
                        Node(3, Node(2)),
                        Node(10, None, Node(20)))
        self.assertEquals(node, expected)

    def test_get(self):
        """ ノード検索 """
        node = Node(5,
                    Node(3, Node(2)),
                    Node(10, None, Node(20)))
        self.assertEquals(node.get(10), Node(10, None, Node(20)))
        self.assertEquals(node.get(999), Node(None))

    def test_is_leaf(self):
        """ リーフかどうか """
        self.assertEquals(Node(10).is_leaf(), True)
        self.assertEquals(Node(10, Node(5)).is_leaf(), False)
        self.assertEquals(Node(10, None, Node(15)).is_leaf(), False)
        self.assertEquals(Node(10, Node(5), Node(15)).is_leaf(), False)

    def test_delete(self):
        """ ノード削除 """

        """ 存在しないノードは消せない """
        node = Node(10)
        self.assertEquals(node.delete(7), False)

        """ ルート(リーフ)のみ """
        node = Node(10)
        self.assertEquals(node.delete(10), True)
        self.assertEquals(node, Node(None))

        """ 該当ノード == リーフ(左) """
        node = Node(10, Node(5))
        self.assertEquals(node.delete(5), True)
        self.assertEquals(node, Node(10))

        """ 該当ノード == リーフ(右) """
        node = Node(10, None, Node(15))
        self.assertEquals(node.delete(15), True)
        self.assertEquals(node, Node(10))

        """ 子が左部分木のみ """
        node = Node(10, Node(5, Node(3, None, Node(4))))
        self.assertEquals(node.delete(5), True)
        self.assertEquals(node, Node(10, Node(3, None, Node(4))))

        """ 子が右部分木のみ """
        node = Node(10, None, Node(15, None, Node(13, None, Node(14))))
        self.assertEquals(node.delete(15), True)
        self.assertEquals(node, Node(10, None, Node(13, None, Node(14))))

        """ 子が両方ともある """
        """
        Before:
                  ┌ーーー20ーーー┐
                  │              │
             ┌ーー7ーー┐        23ーー┐
             │         │              │
         ┌ー4ー┐  ┌ー18              29
         │     │  │
          2     5   10ー┐
                        15
        """
        node = Node(20,
                    Node(7, Node(4, Node(2), Node(5)), Node(18, Node(10, None, Node(15)))),
                    Node(23, None, Node(29)))

        # 上記ツリーから7を削除
        self.assertEquals(node.delete(7), True)

        """
        After:
                  ┌ーーー20ーーー┐
                  │              │
            ┌ーー10ーー┐        23ーー┐
            │          │              │
         ┌ー4ー┐  ┌ー18              29
         │     │  │
          2     5   15
        """
        expected = Node(20,
                    Node(10, Node(4, Node(2), Node(5)), Node(18, Node(15))),
                    Node(23, None, Node(29)))
        self.assertEquals(node, expected)


if __name__ == '__main__':
    unittest.main()

感想

  • 定義したテストコードは通るけど、まだまだ不安
    • テストパターンを洗い出すことが不安を消すことに対して重要だと実感
  • 削除の実装がめんどかった
    • ググって出てくるサンプルは大概削除は実装されていないけど、めんどいからなんだろうな
  • 実装中、whileの条件間違えて何回か無限ループを発生させた
  • もう少し綺麗に出来ないかとも思うけど、そんなに時間割けないからこれでストップ
    • スキル不足を体感した