Numba Library

競技プログラミングのNumba用スニペットを、作成するたびに追記していく。

基本的には素のPythonとあまり変わらず、せいぜいListをnumpy.ndarrayに置きかえただけのものがほとんどだが、 普通に実装するとNumbaでは使えない機能を踏んでしまうものも一部あり、それを避けた実装としてNumba用にまとまってた方が嬉しい。

  • classで作ると事前コンパイルしにくい(できない?)ので、複数の関数を用意する形にする
  • クロージャとして実装(@njit する大枠の関数の中に記述)すれば、各関数を逐一 @njit する必要は無い
  • Numbaのクロージャは再帰ができないので、再帰を用いた実装はなるべく避ける
    • どうしても再帰で書かざるを得ない or 再帰の方がアレンジしやすい場合は、単独の関数としてコンパイルする
      • その場合は @njit(型名) を並記する

クラス変数の代替方法

ちょっと長くなるが、Numba用に移植する際の問題点の1つへの対処法に関する考察を書いておく。

Numbaでは、事前コンパイルをする場合はクラスが使えないのが悩ましい。
おかげでインスタンス変数、つまりは状態を持てないので、常に外部から注入する必要がある。

使う際に何を注入するか意識しなければならないし、管理したい変数が増えると記述もどんどん冗長になる。

素のPythonでのクラスでの実装
1
2
3
4
5
6
7
8
9
10
11
12
class UnionFind:
    def __init(self, n):
        self.table= [-1]* # ←インスタンス変数
     
    def unite(self, a, b):
        # ...略
 
# 使う際にはtableは隠蔽され、中でどう使われているかなんて気にしなくてよい
 
uft= UnionFind(10)
uft.unite(1,5)
uft.unite(2,6)

Numbaでの実装(例)
1
2
3
4
5
6
7
8
9
10
11
12
13
@njit
def main():
    def unite(table, a, b):
        # 略
     
    def find(table, a, b):
        # 略
     
    # 使う際は、外部でtableを定義し、毎回連れ回す必要が生じる
     
    table= [-1]* 10
    unite(table,1,5)
    unite(table,2,6)

また、もう1つの問題点として、Numbaは内部関数も含め、nested(入れ子)なリスト等を引数に取れない。Numpyの多次元配列ならOKだが、それでは表現できないものもある。

(上手くいくこともある? 条件調査中)

入れ子リスト問題
1
2
3
4
5
6
7
@njit
def main():
    def something_function(nested_list, a, b):
        # コンパイル時エラー
     
    nested_list= [[0]for _in range(10)]
    something_function(nested_list,1,5)

Numba関数の中で、nestedなリストを作ることはできる。また、同じ関数内にクロージャ関数を作れば、クロージャ関数から関数外のリストを参照することができる。

これを用いて、以下のようにすれば、something_function の中で nested_list が使える。

入れ子リスト問題解決(暫定)
1
2
3
4
5
6
7
8
9
@njit
def main():
 
    nested_list= [[0]for _in range(10)] # 関数より先にnonlocalな変数を定義
 
    def something_function(a, b):
        nested_list[a][b]= 1               # 変数を使う
     
    something_function(1,5)

しかし、それでは関数と状態が1対1で結びついてしまい、クラスにおける「複数のインスタンスを作る」ようなことができない。

あまり綺麗ではないが、無理矢理解決するとしたら、以下のようになるだろうか。

関数外部にはインスタンス別のリストを記録する NESTED_LIST を用意し、init() ではそこに初期化したリストを加える(これが1つのインスタンス変数となる)。
init() は自身のIDを返すので、以降、something_function() など他の関数を呼ぶ際は、そのIDのみを連れ回す。

これなら、管理したい変数が増えても使う側で管理するのはIDのみで済み、極力、内部実装を意識しないで使える。

ちなみに、init()では、入れ子リストの中身が空だと型推定ができず、コンパイルエラーとなってしまう(9~14行目)。
ちょっと奇妙だが、入れる予定の型が分かるような書き方で lst を定義し、それを空にした後コピーするようにすると上手くいく。

入れ子リスト問題解決
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
@njit
def main():
 
    NESTED_LIST= []
     
    def init(n):
        _id= len(NESTED_LIST)
         
        # × コンパイルエラー
        # NESTED_LIST.append([[] for _ in range(n)])
         
        lst= [0]
        lst.clear()
        NESTED_LIST.append([lst.copy()for _in range(n)])
        return _id
 
    def something_function(_id, a, b):
        nested_list= NESTED_LIST[_id]
        nested_list[a][b]= 1
     
    id1= something_init(10)
    id2= something_init(20)
     
    something_function(id1,1,5)
    something_function(id2,2,6)

あくまでコンパイルが通るというだけで、もっといい書き方があるなら使いたい。

実装例

bit count

2進数表記で'1'の立っている数。

1
2
3
4
5
6
7
def bit_count(x):
    x= (x &0x55555555)+ ((x >> 1) &0x55555555)
    x= (x &0x33333333)+ ((x >> 2) &0x33333333)
    x= (x &0x0F0F0F0F)+ ((x >> 4) &0x0F0F0F0F)
    x= (x &0x00FF00FF)+ ((x >> 8) &0x00FF00FF)
    x= (x &0x0000FFFF)+ ((x >>16) &0x0000FFFF)
    return x

bit length

2進数表記の桁数(0は0)

1
2
3
4
5
6
def bit_length(n):
    ret= 0
    while n >0:
        n >>= 1
        ret+= 1
    return ret

なお、nが1以上かを確認するのに while n: としても素のPythonは通るが、Numbaではバージョンにより nをbool値だと推定してコンパイルしてしまう。
そうなると、引数にどんな正整数を渡しても n=1となってしまい、おかしくなる。ちゃんと int 型であることがわかるような書き方をする。

mod累乗

xaMOD で割った剰余。pythonなら pow(x, a, MOD) で求められるが、Numbaでは第3引数が未対応。二分累乗法で実装。

1
2
3
4
5
6
7
8
9
def mod_pow(x, a, MOD):
    ret= 1
    cur= x
    while a >0:
        if a &1:
            ret= ret* cur% MOD
        cur= cur* cur% MOD
        a >>= 1
    return ret

mod階乗と逆数の事前計算

0!N!とそのモジュラ逆数を計算。上記の mod_pow を使用。
前提として、n<MODかつ MODは素数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def mod_pow(x, a, MOD):
    ret= 1
    cur= x
    while a >0:
        if a &1:
            ret= ret* cur% MOD
        cur= cur* cur% MOD
        a >>= 1
    return ret
 
def precompute_factorials(n, MOD):
    factorials= np.ones(n+ 1, dtype=np.int64)
    for min range(2, n+ 1):
        factorials[m]= factorials[m- 1]* m% MOD
    inversions= np.ones(n+ 1, dtype=np.int64)
    inversions[n]= mod_pow(factorials[n], MOD- 2, MOD)
    for min range(n,2,-1):
        inversions[m- 1]= inversions[m]* m% MOD
    return factorials, inversions

Union-Find

union_find_init()table 配列を生成し、返値をインスタンス番号と見なして各関数に与える。
table の根における値は自グループのサイズを表し、結合の際はサイズが小さい方を大きい方の子とする。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
UNIONFIND_TABLE= []
 
def unionfind_init(n):
    UNIONFIND_TABLE.append(np.full(n,-1, dtype=np.int64))
    return len(UNIONFIND_TABLE)- 1
 
def unionfind_getroot(ins, x):
    table= UNIONFIND_TABLE[ins]
    stack= []
    while table[x] >= 0:
        stack.append(x)
        x= table[x]
    for yin stack:
        table[y]= x
    return x
 
def unionfind_unite(ins, x, y):
    table= UNIONFIND_TABLE[ins]
    r1= unionfind_getroot(ins, x)
    r2= unionfind_getroot(ins, y)
    if r1== r2:
        return
    d1= table[r1]
    d2= table[r2]
    if d1 <= d2:
        table[r2]= r1
        table[r1]+= d2
    else:
        table[r1]= r2
        table[r2]+= d1
 
def unionfind_find(ins, x, y):
    return unionfind_getroot(ins, x)== unionfind_getroot(ins, y)
 
def unionfind_getsize(ins, x):
    table= UNIONFIND_TABLE[ins]
    return -table[unionfind_getroot(ins, x)]

Binary Indexed Tree (Fenwick Tree)

fenwick_init() に要素数を与えて初期化し、返値をインスタンス番号と見なして各関数に与える。

i1Nの値を取るものとする(0始まりではない)。

lower_boundは、累積和が x以上になる最小の iを返す。
使わない場合は、FENWICK_LOGN および fenwick_init 内でのそれを求める処理は不要(残しても大した計算量ではないが)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
FENWICK_TREE= []
FENWICK_LOGN= []
 
def fenwick_init(n):
    log_n= 0
    m= n
    while m:
        log_n+= 1
        m >>= 1
    FENWICK_TREE.append(np.zeros(n+ 1, dtype=np.int64))
    FENWICK_LOGN.append(log_n)
    return len(FENWICK_TREE)- 1
 
def fenwick_add(ins, i, x):
    arr= FENWICK_TREE[ins]
    n= arr.size- 1
    while i <= n:
        arr[i]+= x
        i+= i &-i
 
def fenwick_sum(ins, i):
    arr= FENWICK_TREE[ins]
    result= 0
    while i >0:
        result+= arr[i]
        i ^= i &-i
    return result
 
def fenwick_lower_bound(ins, x):
    arr= FENWICK_TREE[ins]
    log_n= FENWICK_LOGN[ins]
    n= arr.size- 1
    sum_= 0
    pos= 0
    for iin range(log_n,-1,-1):
        k= pos+ (1 << i)
        if k < nand sum_+ arr[k] < x:
            sum_+= arr[k]
            pos+= 1 << i
    return pos+ 1

外部注入できる Fenwick Tree

単位元と演算を外部から注入する版。
ただし、型があまり自由すぎると扱いきれないので、単位元 identity_element の型は np.int64 型固定とし、演算関数 func も「np.int64型の引数を2つとって、1つ返す関数」固定とする。

func は、add, min, xor など operatorモジュールにあるものはそのまま使えるし、自分で定義したものでもよい。

実装

最大流(Dinic法)

辺に容量 capeが決められた有向グラフで、頂点 sから tに流せる最大流量を求める。
二部グラフのマッチングにも使える。

頂点番号は 0N1

基本は、dinic_init で初期化→dinic_add_links でグラフ生成→dinic_maximum_flow で最大流量を計算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
DINIC_LINKS= []
 
def dinic_init(n):
    lst= [[0]]
    lst.clear()
    DINIC_LINKS.append([lst.copy()for _in range(n)])
    return len(DINIC_LINKS)- 1
 
def dinic_add_link(ins, frm, to, cap):
    links= DINIC_LINKS[ins]
    links[frm].append([to, cap,len(links[to])])
    links[to].append([frm,0,len(links[frm])- 1])
 
def dinic_bfs(ins, n, s):
    links= DINIC_LINKS[ins]
    depth= np.full(n,-1, dtype=np.int64)
    depth[s]= 0
    deq= np.zeros(n+ 5, dtype=np.int64)
    dl, dr= 0,1
    deq[0]= s
    while dl < dr:
        v= deq[dl]
        dl+= 1
        for linkin links[v]:
            if link[1] >0 and depth[link[0]]== -1:
                depth[link[0]]= depth[v]+ 1
                deq[dr]= link[0]
                dr+= 1
    return depth
 
def dinic_dfs(ins, depth, progress, s, t):
    links= DINIC_LINKS[ins]
    stack= [(s,10 ** 18)]
    flow= 0
    while stack:
        v, f= stack.pop()
        if v== t:
            flow= f
            continue
        if flow== 0:
            i= progress[v]
            if i== len(links[v]):
                continue
            progress[v]+= 1
            stack.append((v, f))
            to, cap, rev= links[v][i]
            if cap== 0 or depth[v] >= depth[to]:
                continue
            stack.append((to,min(f, cap)))
        else:
            i= progress[v]- 1
            link= links[v][i]
            link[1]-= flow
            links[link[0]][link[2]][1]+= flow
    return flow
 
def dinic_maximum_flow(ins, n, s, t):
    flow= 0
    while True:
        depth= dinic_bfs(ins, n, s)
        if depth[t]== -1:
            return flow
        progress= np.zeros(n, dtype=np.int64)
        path_flow= dinic_dfs(ins, depth, progress, s, t)
        while path_flow != 0:
            flow+= path_flow
            path_flow= dinic_dfs(ins, depth, progress, s, t)

最小費用流

辺に容量 capeと、1単位を流したときのコスト costeが決められた有向グラフで、頂点 sから tに、流量 Qを流した時の最小コストを求める。

頂点番号は 0N1

基本的な使い方は、mincostflow_init で初期化→mincostflow_add_links でグラフ生成→mincostflow_flow で最小費用を計算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from heapqimport heappop, heappush
 
MINCOSTFLOW_LINKS= []
INF= 10 ** 10
 
def mincostflow_init(n):
    """ n: 頂点数 """
    lst= [[0]]
    lst.clear()
    MINCOSTFLOW_LINKS.append([lst.copy()for _in range(n)])
    return len(MINCOSTFLOW_LINKS)- 1
 
def mincostflow_add_link(ins, frm, to, capacity, cost):
    """ インスタンスID, 辺始点頂点番号, 辺終点頂点番号, 容量, コスト """
    links= MINCOSTFLOW_LINKS[ins]
    links[frm].append([to, capacity, cost,len(links[to])])
    links[to].append([frm,0,-cost,len(links[frm])- 1])
 
def mincostflow_flow(ins, s, t, quantity):
    """ インスタンスID, フロー始点頂点番号, フロー終点頂点番号, 要求流量 """
    links= MINCOSTFLOW_LINKS[ins]
    n= len(links)
    res= 0
    potentials= np.zeros(n, dtype=np.int64)
    dist= np.full(n, INF, dtype=np.int64)
    prev_v= np.full(n,-1, dtype=np.int64)
    prev_e= np.full(n,-1, dtype=np.int64)
 
    while quantity:
        dist.fill(INF)
        dist[s]= 0
        que= [(0, s)]
 
        while que:
            total_cost, v= heappop(que)
            if dist[v] < total_cost:
                continue
            for i, (u, cap, cost, _)in enumerate(links[v]):
                new_cost= dist[v]+ potentials[v]- potentials[u]+ cost
                if cap >0 and new_cost < dist[u]:
                    dist[u]= new_cost
                    prev_v[u]= v
                    prev_e[u]= i
                    heappush(que, (new_cost, u))
 
        # Cannot flow quantity
        if dist[t]== INF:
            return -1
 
        potentials+= dist
 
        cur_flow= quantity
        v= t
        while v != s:
            cur_flow= min(cur_flow, links[prev_v[v]][prev_e[v]][1])
            v= prev_v[v]
        quantity-= cur_flow
        res+= cur_flow* potentials[t]
 
        v= t
        while v != s:
            link= links[prev_v[v]][prev_e[v]]
            link[1]-= cur_flow
            links[v][link[3]][1]+= cur_flow
            v= prev_v[v]
 
    return res

留意点

同じNumpy配列同士を演算するとエラー

AtCoderで使われて いる いた過去の Numba 0.48.0 では、同じNumPy配列同士を演算すると(?)エラーになる。(※詳細な条件はちゃんと調べてない)

0.53では修正されているのを確認している。

一方を別の名前の変数で定義してやると大丈夫になる。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# エラー(配列 x の各値をそれぞれ a 乗するコード)
@njit
def mod_pow(x, a):
    ret= np.ones_like(x)
    cur= x
    while a >0:
        if a &1:
            ret= ret* cur% MOD
        cur= cur* cur% MOD # ←エラー
        a >>= 1
    return ret
 
# おっけー
def mod_pow(x, a):
    ret= np.ones_like(x)
    cur= x
    while a >0:
        if a &1:
            ret= ret* cur% MOD
        cur_= cur
        cur= cur* cur_% MOD
        a >>= 1
    return ret

整数が0か0以外かの判定はちゃんと書く

Numba 0.57.0 で確認。

Pythonでは、整数が0か0以外かの判定に「if a:」「while a:」などと書いても解釈してくれるが、Numbaではbool値として解釈されてしまうことがある。

その場合、コンパイルされた関数中の aは全てbool値となるので、正整数を渡しても強制的に 1 になるなど、おかしくなる。(詳細な条件は不明)

1
2
3
4
5
6
7
8
9
10
11
12
13
@njit
def main():
    def mod_pow(x, a):
        ret= 1
        cur= x
        while a:       # ← ここの記述からか、a は関数全体を通してbool値として扱われる
            if a &1:
                ret= ret* cur
            cur= cur* cur
            a >>= 1
        return ret
 
    print(mod_pow(10,5)) # 10^1 として渡ってしまい、10 が返る

while a > 0:」など、明示的にint型であることが分かるような書き方をする必要がある。

programming_algorithm/python_tips/numba_library.txt · 最終更新: 2023/10/16 by ikatakos
CC Attribution 4.0 International
Driven by DokuWikiRecent changes RSS feedValid CSSValid XHTML 1.0