LoginSignup

Why not login to Qiita and try out its useful features?

We'll deliver articles that match you.

You can read useful information later.

74
54

More than 3 years have passed since last update.

平衡二分木を実装する

Last updated at Posted at 2020-06-15

回転のいらない平衡二分木を実装したい

Python では組み込み関数に平衡二分木を扱えるものがないので自作する必要があります。よくある平衡二分木では、平衡を保つために「回転」の操作をしないといけないのですが、この処理が実装的にもパフォーマンス的にも結構重いので、なるべく回避する方法を使いたいです。
ここでは、 ピボット値 を設定することで回転のいらない平衡二分木 1 を実装する方法を紹介します。 Python のコードも示します。

やりたいこと

  • 整数値を取る平衡二分木を作る
  • 取りうる整数値xの範囲を1x<Lとするとき、構築はO(1)、挿入・削除・検索はO(logL)でできる

なお本稿では簡単のため、同じ値を複数個追加することはない(すでに存在する値を追加しようとした場合は何も起きない)としています。
必要な場合、①各整数の個数を管理するための Dict を使う、②IDを末尾に入れて Unique にする、などによって容易に修正できます。問題例2でも扱います。

方針(制約)

通常の二分探索木は、追加の順番によっては片側に伸びてしまって平衡にならないことがあります。平衡を保つためには、ある種の制約を課してやる必要があります。

復習(AVL木、赤黒木)

AVL 木は「どのノードの左右部分木の高さの差も1以下」という制約を課すことで平衡を担保しています。また赤黒木は各ノードに黒または赤の色を対応させて、「赤ノードの子は黒」、「任意の葉から根までにある黒ノードの個数は一定」という制約を課すことで平衡を担保しています。
AVL 木も赤黒木も、制約を満たさなくなると「回転」をすることによって平衡に保つという方法を取っていました。

本稿での制約

扱う整数xの範囲は1x<L=2Kとしておきます。通常の二分探索木の条件に加えて下記を満たすものを「ピボット木(Pivot Tree)」と呼びましょう。

  • 各ノードに「ピボット」という値を設定する
    • 根のピボット値は2K1
    • ピボット値がp(偶数)のノードの左の子のピボット値はplsb(p)/2、右の子のピボット値はp+lsb(p)/2(*)
  • 各ノードについて、
    • 左の子(およびその子孫)の値はピボット値より小さい
    • 右の子(およびその子孫)の値はピボット値より大きい

(*)lsbは最下位ビットを表します。またピボット値が奇数のノードの子のピボット値は参照されないので何でもいいんですが、ここでは未定義としておきます。参照されないというのは具体例のところを見てもらえると良いと思います。

ピボット木は平衡二分木になります。 より具体的には、高さがK=log2Lを超えないことが示せます。

具体例

L=16, K=4としてみましょう。すると、ピボットは次のようになります。

pivot2.png

赤字 はピボットを表します。必ずしもその数しか入らない訳ではありません。
ピボットがaである頂点を、単に頂点aと呼びます。

最初はすべて空欄です。ここから順番に要素を追加していきます。
なお、上の絵ではすべての頂点を最初から描いていますが、実装にあたっては 数が入るところのみノードを追加 すれば良いです。つまり、最初はノードが何もない状態です。

要素の追加

1を追加
pivot_1.png

最初は一番上(頂点8)に追加します。

2を追加
pivot_2.png

12も、頂点8より右には行けません。ここでは小さい方の1を左の子に移します。

3を追加
pivot_3.png

左にずれます。

4を追加
pivot_4.png

さらに左にずれます。2は、頂点2に止まります。2は頂点2に乗ることはできますが、このどちらの子にも移動できないことに注意してください。
なお1は一番下の段(頂点1)に到達しました。ピボット値が奇数の頂点の子のピボットは定義されていませんが、ここにはさらに別の数が降って来ることはないので問題ありません。

5を追加
pivot_5.png

3は頂点2では右の子に移動します。

6を追加
pivot_6.png

5は頂点4では右の子に移動します。
こんな感じで、どんな場合でもK段目より下に行くことはありません。厳密には、各ノードに入りうる整数の範囲を考えると示せます。

要素の削除

ここから削除です。

5を削除
pivot_del5.png

下に子がなければそのまま削除するだけです。

2を削除
pivot_del2.png

左右に子があれば、(通常の二分探索木の削除と同様に)自分より大きいもののうち最小のものを今いる位置に移動させます。このとき、ピボットの条件が崩れることはありません。
左の子だけある場合は、自分より小さいもののうち最大のものを今いる位置に移動させれば良いです。

要素の検索

ある値以上(以下)の最小(最大)の要素を求めるなどです。これは通常の二分探索木と同様に、根から順に、探したい値がノードの値より小さければ左へ、大きければ右へ行くのを繰り返せば、最悪K (=log2L)ステップでたどり着けます。
全体の中での最小値・最大値も同様に求まります。

問題例

AtCoder の過去問を2つほど紹介します。ネタバレを含みますがご了承ください。

問題例1

CPSCO 2019 1-E (Exclusive Or Queries)

当時 Python / PyPy ではきついと言われてみんなで頑張ってたやつです。

その後、いろんな人がいろんな方法で通してましたね。 BIT とセグ木という方法もあったと思います。結果的にとても教育的な問題だったと思っています。(ちなみにてんぷらさんもその後通してました。)
私も平衡二分木の整備のきっかけになったので良かったです。

ACコード →

取りうる整数の範囲が109程度あるので座圧する手もありますが、本稿の方法だとしなくても大丈夫です。
なおこの問題では、「すでに存在する値を追加しようとした場合」に、その値を削除するようにしています。

問題例2

ABC170-E

K=48つまり追加されうる整数の種類が 2481の平衡二分木を20万本ほど使っています(構築自体はO(1)でできるので、このようにたくさん持つこともできます)。48ビットのうち、上位30ビットはメインの整数(問題文でいう「レート」)を表し、下位18ビットは幼児の ID を表します(重複があるとめんどくさいので ID をつけて区別しています)。 区別するためだけにKを増やして定数倍がもったいないと思うかもしれませんが、実際には上位30ビットの時点で要素が区別されるので、ピボット木は30段程度までしか必要ありません。 → よく考えたら元の整数が全部一致してたりするとだめですね(汗)

ACコード →

なおこの問題は検索が最小値(または最大値)のみなので heapq でも実装できるため、平衡二分木はややオーバーキル感もありますが、遅延処理が不要になるため(ライブラリを持っていれば)実装はラクになります。

heapqを用いる方法 →

実装

上に書いたとおり内部的には1x<Lを扱っていますが、実際には0を扱いたいことも多いので、実装では値を1ずらして保持しています。つまり外から見ると0以上L2 (=2K2)以下の整数を扱えるようにしています。
また、要素が1つもないと場合分けがめんどいので、ダミーの根としてinf=L1(ずらし後)を必ず入れるようにしています。

test.py

class BalancingTree:
    def __init__(self, n):
        self.N = n
        self.root = self.node(1<<n, 1<<n)

    def append(self, v):# v を追加(その時点で v はない前提)
        v += 1
        nd = self.root
        while True:
            if v == nd.value:
                # v がすでに存在する場合に何か処理が必要ならここに書く
                return 0
            else:
                mi, ma = min(v, nd.value), max(v, nd.value)
                if mi < nd.pivot:
                    nd.value = ma
                    if nd.left:
                        nd = nd.left
                        v = mi
                    else:
                        p = nd.pivot
                        nd.left = self.node(mi, p - (p&-p)//2)
                        break
                else:
                    nd.value = mi
                    if nd.right:
                        nd = nd.right
                        v = ma
                    else:
                        p = nd.pivot
                        nd.right = self.node(ma, p + (p&-p)//2)
                        break

    def leftmost(self, nd):
        if nd.left: return self.leftmost(nd.left)
        return nd

    def rightmost(self, nd):
        if nd.right: return self.rightmost(nd.right)
        return nd

    def find_l(self, v): # vより真に小さいやつの中での最大値(なければ-1)
        v += 1
        nd = self.root
        prev = 0
        if nd.value < v: prev = nd.value
        while True:
            if v <= nd.value:
                if nd.left:
                    nd = nd.left
                else:
                    return prev - 1
            else:
                prev = nd.value
                if nd.right:
                    nd = nd.right
                else:
                    return prev - 1

    def find_r(self, v): # vより真に大きいやつの中での最小値(なければRoot)
        v += 1
        nd = self.root
        prev = 0
        if nd.value > v: prev = nd.value
        while True:
            if v < nd.value:
                prev = nd.value
                if nd.left:
                    nd = nd.left
                else:
                    return prev - 1
            else:
                if nd.right:
                    nd = nd.right
                else:
                    return prev - 1

    @property
    def max(self):
        return self.find_l((1<<self.N)-1)

    @property
    def min(self):
        return self.find_r(-1)

    def delete(self, v, nd = None, prev = None): # 値がvのノードがあれば削除(なければ何もしない)
        v += 1
        if not nd: nd = self.root
        if not prev: prev = nd
        while v != nd.value:
            prev = nd
            if v <= nd.value:
                if nd.left:
                    nd = nd.left
                else:
                    #####
                    return
            else:
                if nd.right:
                    nd = nd.right
                else:
                    #####
                    return
        if (not nd.left) and (not nd.right):
            if not prev.left:
                prev.right = None
            elif not prev.right:
                prev.left = None
            else:
                if nd.pivot == prev.left.pivot:
                    prev.left = None
                else:
                    prev.right = None

        elif nd.right:
            # print("type A", v)
            nd.value = self.leftmost(nd.right).value
            self.delete(nd.value - 1, nd.right, nd)    
        else:
            # print("type B", v)
            nd.value = self.rightmost(nd.left).value
            self.delete(nd.value - 1, nd.left, nd)

    def __contains__(self, v: int) -> bool:
        return self.find_r(v - 1) == v

    class node:
        def __init__(self, v, p):
            self.value = v
            self.pivot = p
            self.left = None
            self.right = None

    def debug(self):
        def debug_info(nd_):
            return (nd_.value - 1, nd_.pivot - 1, nd_.left.value - 1 if nd_.left else -1, nd_.right.value - 1 if nd_.right else -1)

        def debug_node(nd):
            re = []
            if nd.left:
                re += debug_node(nd.left)
            if nd.value: re.append(debug_info(nd))
            if nd.right:
                re += debug_node(nd.right)
            return re
        print("Debug - root =", self.root.value - 1, debug_node(self.root)[:50])

    def debug_list(self):
        def debug_node(nd):
            re = []
            if nd.left:
                re += debug_node(nd.left)
            if nd.value: re.append(nd.value - 1)
            if nd.right:
                re += debug_node(nd.right)
            return re
        return debug_node(self.root)[:-1]

BT = BalancingTree(5) # 0 ~ 30 までの要素を入れられるピボット木
BT.append(3)
BT.append(20)
BT.append(5)
BT.append(10)
BT.append(13)
BT.append(8)
BT.debug()
BT.delete(20)
BT.debug()
print(BT.find_l(12)) # 10
print(BT.find_r(5)) # 8
print(BT.min) # 3
print(BT.max) # 13
print(3 in BT) # True
print(4 in BT) # False
BT.debug_list()

# 愚直チェック
from random import randrange
BT = BalancingTree(5) # 0 ~ 30 までの要素を入れられるピボット木
S = set()
for _ in range(1000):
    a = randrange(31)
    if randrange(2) == 0:
        BT.append(a)
        S.add(a)

    else:
        BT.delete(a)
        if a in S: S.remove(a)
    if BT.debug_list() != sorted(list(S)):
        print("NG!!")
    # print(BT.debug_list(), sorted(list(S)))
print("END")

その後 solzard さんにいくつか指摘してもらったので修正しました(上のコードにも反映してます)。

さらに chineristAC さんにも指摘もらったので修正しました(反映がめちゃくちゃ遅くなりました)。ランダムチェックも入れてみましたが今度こそちゃんと動いてそうです。

値が実数のとき

実数範囲でも同様のピボット木を作ることはできます。ただし、オーバーフローには注意する必要があります。有理数型を使うなどしてオーバーフローの問題は解決されるかもしれませんが、とても近い範囲にたくさんの要素が集中すると、平衡が保たれなくなってしまいます。
上では、ピボットは1段下りるごとにちょうど半分ずつになるように設定しましたが、特定の位置に要素が固まりやすいことが分かっていれば必ずしもそうする必要はありません。具体的には、実数xの確率分布が与えられると、累積分布関数がぴったり等分される位置に設定すると効率が良いです(でもほとんどの場合は定数倍の差しかないと思います)。


  1. この記事では「平衡」は高さが要素数の対数オーダーで抑えられる、ぐらいの意味で使っています。 

74
54
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
Kiri8128

@Kiri8128(Kiri8128)

Python で競技プログラミングをしています。
Linked from these articles

Comments

@atusify(atu atu)
(Edited)

いつも大変有用な記事ありがとうございます!

なお本稿では簡単のため、同じ値を複数個追加することはない(すでに存在する値を追加しようとした場合は何も起きない)としています。
必要な場合、①各整数の個数を管理するための Dict を使う、②IDを末尾に入れて Unique にする、などによって容易に修正できます。問題例2でも扱います。

上記について、Dictをクラスに持たせておけば要素の存在確認や重複がある場合のDelete処理が早くなり、ID追加するよりコーディングが楽かなと思い(IDが必要なケースもあるかもしれませんが)、実装しました(self.dを導入)。
また、負の数も入れられると便利な場合もあるかと思い、初期化を区間を指定する形にして拡張してみました(self.shiftを導入)。

大変恥ずかしながらdelete処理の実装詳細をまだ理解しておらず、私の追加実装部分に相当誤りがありそうですが、何らかのご参考まで(CPSCO 2019 1-E, abc170-e, abc140-fが通ることを確認しました)。

class BalancingTree3:
    """
    ref: https://qiita.com/Kiri8128/items/6256f8559f0026485d90
    ref: https://atcoder.jp/contests/abc140/tasks/abc140_f
    ref: https://atcoder.jp/contests/cpsco2019-s1/tasks/cpsco2019_s1_e
    ref: https://atcoder.jp/contests/abc140/tasks/abc140_e

    構築はO(1)
    >>> BT = BalancingTree3(-20, 26) # 少なくとも-20 ~ 26の整数を格納できるBT
    >>> BT.max
    -21
    >>> BT.append(3)
    >>> BT.append(20)
    >>> BT.append(5)
    >>> BT.append(10)
    >>> BT.append(13)
    >>> BT.append(8)
    >>> BT.delete(20)
    >>> BT.append(-5)
    >>> BT.append(-20)
    >>> BT.append(25)
    >>> BT.find_l(12)
    10
    >>> BT.find_r(5)
    8
    >>> BT.find_r(25)
    27
    >>> BT.min
    -20
    >>> BT.max 
    25
    >>> 3 in BT
    True
    >>> 4 in BT
    False
    >>> BT.append(26)
    >>> BT.find_r(25)
    26
    >>> BT.find_r(26)
    27
    >>> BT.append(28)
    Traceback (most recent call last):
        ...
    AssertionError: value must be between -20 and 26
    >>> BT.count(3)
    1
    >>> BT.append(3)
    >>> BT.count(3)
    2
    """
    def __init__(self, v_least, v_most):
        self.v_least = v_least
        self.v_most = v_most
        self.shift = 1 - self.v_least # 内部では1~1+(v_most-v_least)を持つ
        self.N = 1 << (v_most-v_least+1).bit_length()
        self.root = self.node(self.N, self.N)
        self.d = defaultdict(int)

    def count(self, v):
        return self.d[v]

    def debug(self):
        def debug_info(nd_):
            return (nd_.value - self.shift, nd_.pivot - self.shift, nd_.left.value - self.shift if nd_.left else -self.shift, nd_.right.value - self.shift if nd_.right else -self.shift)

        def debug_node(nd):
            re = []
            if nd.left:
                re += debug_node(nd.left)
            if nd.value: re.append(debug_info(nd))
            if nd.right:
                re += debug_node(nd.right)
            return re
        print("Debug - root =", self.root.value - self.shift, debug_node(self.root)[:50])

    def debug_list(self):
        def debug_node(nd):
            re = []
            if nd.left:
                re += debug_node(nd.left)
            if nd.value: re.append(nd.value - self.shift)
            if nd.right:
                re += debug_node(nd.right)
            return re
        return debug_node(self.root)[:-1]

    def append(self, v):# v を追加
        assert self.v_least <= v <= self.v_most, f"value must be between {self.v_least} and {self.v_most}"
        iv = v + self.shift # internal value
        nd = self.root
        if self.d[v] > 0:
            self.d[v] += 1
            return 0
        else:
            self.d[v] = 1
        while True:
            if iv == nd.value:
                # iv がすでに存在する場合に何か処理が必要ならここに書く
                return 0
            else:
                mi, ma = min(iv, nd.value), max(iv, nd.value)
                if mi < nd.pivot:
                    nd.value = ma
                    if nd.left:
                        nd = nd.left
                        iv = mi
                    else:
                        p = nd.pivot
                        nd.left = self.node(mi, p - (p&-p)//2)
                        break
                else:
                    nd.value = mi
                    if nd.right:
                        nd = nd.right
                        iv = ma
                    else:
                        p = nd.pivot
                        nd.right = self.node(ma, p + (p&-p)//2)
                        break

    def leftmost(self, nd):
        if nd.left: return self.leftmost(nd.left)
        return nd

    def rightmost(self, nd):
        if nd.right: return self.rightmost(nd.right)
        return nd

    def find_l(self, v): 
        """vより真に小さいやつの中での最大値(なければ最小値-1)

        Args:
            v (int): value

        Returns:
            int: vより真に小さいやつの中での最大値(なければ最小値-1)
        """
        iv = v + self.shift
        nd = self.root
        prev = 0
        if nd.value < iv: prev = nd.value
        while True:
            if iv <= nd.value:
                if nd.left:
                    nd = nd.left
                else:
                    return prev - self.shift
            else:
                prev = nd.value
                if nd.right:
                    nd = nd.right
                else:
                    return prev - self.shift

    def find_r(self, v): 
        """vより真に大きいやつの中での最小値(なければself.v_most+1)

        # 実装追加したことでRootとv_most+1が異なってしまったので、Rootではなくv_most+1を返すようにした。
        # ref: https://atcoder.jp/contests/abc140/tasks/abc140_e

        Args:
            v (int): value

        Returns:
            int: vより真に大きいやつの中での最小値(なければself.v_most+1)
        """
        if v > self.v_most: return self.v_most+1 # 追加した
        iv = v + self.shift
        nd = self.root
        prev = 0
        if nd.value > iv: prev = nd.value
        while True:
            if iv < nd.value:
                prev = nd.value
                if nd.left:
                    nd = nd.left
                else:
                    return min(prev - self.shift, self.v_most+1)
            else:
                if nd.right:
                    nd = nd.right
                else:
                    return min(prev - self.shift, self.v_most+1)

    @property
    def max(self):
        return self.find_l(self.N-self.shift)

    @property
    def min(self):
        return self.find_r(-self.shift)

    def delete(self, v, nd = None, prev = None): # 値がvのノードがあれば削除(なければ何もしない)
        iv = v + self.shift
        if self.d[v] > 1 and nd is None:
            self.d[v] -= 1
            return
        elif not self.d[v] and nd is None:
            return
        else:
            if nd is None:
                del self.d[v]
                nd = self.root
            if prev is None: prev = nd
            while iv != nd.value:
                prev = nd
                if iv <= nd.value:
                    if nd.left:
                        nd = nd.left
                    else:
                        return
                else:
                    if nd.right:
                        nd = nd.right
                    else:
                        return
            if (nd.left is None) and (nd.right is None):
                if prev.left is None:
                    prev.right = None
                elif prev.right is None:
                    prev.left = None
                else:
                    if nd.pivot == prev.left.pivot:
                        prev.left = None
                    else:
                        prev.right = None
            elif nd.right:
                nd.value = self.leftmost(nd.right).value
                self.delete(nd.value - self.shift, nd.right, nd)
            else:
                nd.value = self.rightmost(nd.left).value
                self.delete(nd.value - self.shift, nd.left, nd)

    def __contains__(self, v: int) -> bool:
        # return self.find_r(v - self.shift) == v
        return self.d[v] >= 1

    class node:
        def __init__(self, v, p):
            self.value = v
            self.pivot = p
            self.left = None
            self.right = None

1
@gcqmkm

勉強になりました。
でも、いつもでv += 1の理由が理解できません。
ありがとうございました。

0

Let's comment your feelings that are more than good

Being held Article posting campaign

paiza×Qiita記事投稿キャンペーン「プログラミング問題をやってみて書いたコードを投稿しよう!」

~
View details
74
54

Login to continue?

Login or Sign up with social account

Login or Sign up with your email address