Numba Library
競技プログラミングのNumba用スニペットを、作成するたびに追記していく。
基本的には素のPythonとあまり変わらず、せいぜいListをnumpy.ndarrayに置きかえただけのものがほとんどだが、 普通に実装するとNumbaでは使えない機能を踏んでしまうものも一部あり、それを避けた実装としてNumba用にまとまってた方が嬉しい。
- classで作ると事前コンパイルしにくい(できない?)ので、複数の関数を用意する形にする
- クロージャとして実装(
@njitする大枠の関数の中に記述)すれば、各関数を逐一@njitする必要は無い - Numbaのクロージャは再帰ができないので、再帰を用いた実装はなるべく避ける
- どうしても再帰で書かざるを得ない or 再帰の方がアレンジしやすい場合は、単独の関数としてコンパイルする
- その場合は
@njit(型名)を並記する
クラス変数の代替方法
ちょっと長くなるが、Numba用に移植する際の問題点の1つへの対処法に関する考察を書いておく。
Numbaでは、事前コンパイルをする場合はクラスが使えないのが悩ましい。
おかげでインスタンス変数、つまりは状態を持てないので、常に外部から注入する必要がある。
使う際に何を注入するか意識しなければならないし、管理したい変数が増えると記述もどんどん冗長になる。
1 2 3 4 5 6 7 8 9 10 11 12 | class UnionFind: def __init(self, n): self.table= [-1]* n # ←インスタンス変数 def unite(self, a, b): # ...略# 使う際にはtableは隠蔽され、中でどう使われているかなんて気にしなくてよいuft= UnionFind(10)uft.unite(1,5)uft.unite(2,6) |
1 2 3 4 5 6 7 8 9 10 11 12 13 | @njitdef main(): def unite(table, a, b): # 略 def find(table, a, b): # 略 # 使う際は、外部でtableを定義し、毎回連れ回す必要が生じる table= [-1]* 10 unite(table,1,5) unite(table,2,6) |
また、もう1つの問題点として、Numbaは内部関数も含め、nested(入れ子)なリスト等を引数に取れない。Numpyの多次元配列ならOKだが、それでは表現できないものもある。
(上手くいくこともある? 条件調査中)
1 2 3 4 5 6 7 | @njitdef main(): def something_function(nested_list, a, b): # コンパイル時エラー nested_list= [[0]for _in range(10)] something_function(nested_list,1,5) |
Numba関数の中で、nestedなリストを作ることはできる。また、同じ関数内にクロージャ関数を作れば、クロージャ関数から関数外のリストを参照することができる。
これを用いて、以下のようにすれば、something_function の中で nested_list が使える。
1 2 3 4 5 6 7 8 9 | @njitdef main(): nested_list= [[0]for _in range(10)] # 関数より先にnonlocalな変数を定義 def something_function(a, b): nested_list[a][b]= 1 # 変数を使う something_function(1,5) |
しかし、それでは関数と状態が1対1で結びついてしまい、クラスにおける「複数のインスタンスを作る」ようなことができない。
あまり綺麗ではないが、無理矢理解決するとしたら、以下のようになるだろうか。
関数外部にはインスタンス別のリストを記録する NESTED_LIST を用意し、init() ではそこに初期化したリストを加える(これが1つのインスタンス変数となる)。init() は自身のIDを返すので、以降、something_function() など他の関数を呼ぶ際は、そのIDのみを連れ回す。
これなら、管理したい変数が増えても使う側で管理するのはIDのみで済み、極力、内部実装を意識しないで使える。
ちなみに、init()では、入れ子リストの中身が空だと型推定ができず、コンパイルエラーとなってしまう(9~14行目)。
ちょっと奇妙だが、入れる予定の型が分かるような書き方で lst を定義し、それを空にした後コピーするようにすると上手くいく。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | @njitdef main(): NESTED_LIST= [] def init(n): _id= len(NESTED_LIST) # × コンパイルエラー # NESTED_LIST.append([[] for _ in range(n)]) lst= [0] lst.clear() NESTED_LIST.append([lst.copy()for _in range(n)]) return _id def something_function(_id, a, b): nested_list= NESTED_LIST[_id] nested_list[a][b]= 1 id1= something_init(10) id2= something_init(20) something_function(id1,1,5) something_function(id2,2,6) |
あくまでコンパイルが通るというだけで、もっといい書き方があるなら使いたい。
実装例
bit count
2進数表記で'1'の立っている数。
1 2 3 4 5 6 7 | def bit_count(x): x= (x &0x55555555)+ ((x >> 1) &0x55555555) x= (x &0x33333333)+ ((x >> 2) &0x33333333) x= (x &0x0F0F0F0F)+ ((x >> 4) &0x0F0F0F0F) x= (x &0x00FF00FF)+ ((x >> 8) &0x00FF00FF) x= (x &0x0000FFFF)+ ((x >>16) &0x0000FFFF) return x |
bit length
2進数表記の桁数(0は0)
1 2 3 4 5 6 | def bit_length(n): ret= 0 while n >0: n >>= 1 ret+= 1 return ret |
なお、が1以上かを確認するのに while n: としても素のPythonは通るが、Numbaではバージョンにより をbool値だと推定してコンパイルしてしまう。
そうなると、引数にどんな正整数を渡しても となってしまい、おかしくなる。ちゃんと int 型であることがわかるような書き方をする。
mod累乗
を MOD で割った剰余。pythonなら pow(x, a, MOD) で求められるが、Numbaでは第3引数が未対応。二分累乗法で実装。
1 2 3 4 5 6 7 8 9 | def mod_pow(x, a, MOD): ret= 1 cur= x while a >0: if a &1: ret= ret* cur% MOD cur= cur* cur% MOD a >>= 1 return ret |
mod階乗と逆数の事前計算
とそのモジュラ逆数を計算。上記の mod_pow を使用。
前提として、かつ は素数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | def mod_pow(x, a, MOD): ret= 1 cur= x while a >0: if a &1: ret= ret* cur% MOD cur= cur* cur% MOD a >>= 1 return retdef precompute_factorials(n, MOD): factorials= np.ones(n+ 1, dtype=np.int64) for min range(2, n+ 1): factorials[m]= factorials[m- 1]* m% MOD inversions= np.ones(n+ 1, dtype=np.int64) inversions[n]= mod_pow(factorials[n], MOD- 2, MOD) for min range(n,2,-1): inversions[m- 1]= inversions[m]* m% MOD return factorials, inversions |
Union-Find
union_find_init() で table 配列を生成し、返値をインスタンス番号と見なして各関数に与える。table の根における値は自グループのサイズを表し、結合の際はサイズが小さい方を大きい方の子とする。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | UNIONFIND_TABLE= []def unionfind_init(n): UNIONFIND_TABLE.append(np.full(n,-1, dtype=np.int64)) return len(UNIONFIND_TABLE)- 1def unionfind_getroot(ins, x): table= UNIONFIND_TABLE[ins] stack= [] while table[x] >= 0: stack.append(x) x= table[x] for yin stack: table[y]= x return xdef unionfind_unite(ins, x, y): table= UNIONFIND_TABLE[ins] r1= unionfind_getroot(ins, x) r2= unionfind_getroot(ins, y) if r1== r2: return d1= table[r1] d2= table[r2] if d1 <= d2: table[r2]= r1 table[r1]+= d2 else: table[r1]= r2 table[r2]+= d1def unionfind_find(ins, x, y): return unionfind_getroot(ins, x)== unionfind_getroot(ins, y)def unionfind_getsize(ins, x): table= UNIONFIND_TABLE[ins] return -table[unionfind_getroot(ins, x)] |
Binary Indexed Tree (Fenwick Tree)
fenwick_init() に要素数を与えて初期化し、返値をインスタンス番号と見なして各関数に与える。
は の値を取るものとする(0始まりではない)。
lower_boundは、累積和が 以上になる最小の を返す。
使わない場合は、FENWICK_LOGN および fenwick_init 内でのそれを求める処理は不要(残しても大した計算量ではないが)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | FENWICK_TREE= []FENWICK_LOGN= []def fenwick_init(n): log_n= 0 m= n while m: log_n+= 1 m >>= 1 FENWICK_TREE.append(np.zeros(n+ 1, dtype=np.int64)) FENWICK_LOGN.append(log_n) return len(FENWICK_TREE)- 1def fenwick_add(ins, i, x): arr= FENWICK_TREE[ins] n= arr.size- 1 while i <= n: arr[i]+= x i+= i &-idef fenwick_sum(ins, i): arr= FENWICK_TREE[ins] result= 0 while i >0: result+= arr[i] i ^= i &-i return resultdef fenwick_lower_bound(ins, x): arr= FENWICK_TREE[ins] log_n= FENWICK_LOGN[ins] n= arr.size- 1 sum_= 0 pos= 0 for iin range(log_n,-1,-1): k= pos+ (1 << i) if k < nand sum_+ arr[k] < x: sum_+= arr[k] pos+= 1 << i return pos+ 1 |
外部注入できる Fenwick Tree
単位元と演算を外部から注入する版。
ただし、型があまり自由すぎると扱いきれないので、単位元 identity_element の型は np.int64 型固定とし、演算関数 func も「np.int64型の引数を2つとって、1つ返す関数」固定とする。
func は、add, min, xor など operatorモジュールにあるものはそのまま使えるし、自分で定義したものでもよい。
最大流(Dinic法)
辺に容量 が決められた有向グラフで、頂点 から に流せる最大流量を求める。
二部グラフのマッチングにも使える。
頂点番号は 。
基本は、dinic_init で初期化→dinic_add_links でグラフ生成→dinic_maximum_flow で最大流量を計算。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | DINIC_LINKS= []def dinic_init(n): lst= [[0]] lst.clear() DINIC_LINKS.append([lst.copy()for _in range(n)]) return len(DINIC_LINKS)- 1def dinic_add_link(ins, frm, to, cap): links= DINIC_LINKS[ins] links[frm].append([to, cap,len(links[to])]) links[to].append([frm,0,len(links[frm])- 1])def dinic_bfs(ins, n, s): links= DINIC_LINKS[ins] depth= np.full(n,-1, dtype=np.int64) depth[s]= 0 deq= np.zeros(n+ 5, dtype=np.int64) dl, dr= 0,1 deq[0]= s while dl < dr: v= deq[dl] dl+= 1 for linkin links[v]: if link[1] >0 and depth[link[0]]== -1: depth[link[0]]= depth[v]+ 1 deq[dr]= link[0] dr+= 1 return depthdef dinic_dfs(ins, depth, progress, s, t): links= DINIC_LINKS[ins] stack= [(s,10 ** 18)] flow= 0 while stack: v, f= stack.pop() if v== t: flow= f continue if flow== 0: i= progress[v] if i== len(links[v]): continue progress[v]+= 1 stack.append((v, f)) to, cap, rev= links[v][i] if cap== 0 or depth[v] >= depth[to]: continue stack.append((to,min(f, cap))) else: i= progress[v]- 1 link= links[v][i] link[1]-= flow links[link[0]][link[2]][1]+= flow return flowdef dinic_maximum_flow(ins, n, s, t): flow= 0 while True: depth= dinic_bfs(ins, n, s) if depth[t]== -1: return flow progress= np.zeros(n, dtype=np.int64) path_flow= dinic_dfs(ins, depth, progress, s, t) while path_flow != 0: flow+= path_flow path_flow= dinic_dfs(ins, depth, progress, s, t) |
最小費用流
辺に容量 と、1単位を流したときのコスト が決められた有向グラフで、頂点 から に、流量 を流した時の最小コストを求める。
頂点番号は 。
基本的な使い方は、mincostflow_init で初期化→mincostflow_add_links でグラフ生成→mincostflow_flow で最小費用を計算。
- 実装に当たっての参考
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | from heapqimport heappop, heappushMINCOSTFLOW_LINKS= []INF= 10 ** 10def mincostflow_init(n): """ n: 頂点数 """ lst= [[0]] lst.clear() MINCOSTFLOW_LINKS.append([lst.copy()for _in range(n)]) return len(MINCOSTFLOW_LINKS)- 1def mincostflow_add_link(ins, frm, to, capacity, cost): """ インスタンスID, 辺始点頂点番号, 辺終点頂点番号, 容量, コスト """ links= MINCOSTFLOW_LINKS[ins] links[frm].append([to, capacity, cost,len(links[to])]) links[to].append([frm,0,-cost,len(links[frm])- 1])def mincostflow_flow(ins, s, t, quantity): """ インスタンスID, フロー始点頂点番号, フロー終点頂点番号, 要求流量 """ links= MINCOSTFLOW_LINKS[ins] n= len(links) res= 0 potentials= np.zeros(n, dtype=np.int64) dist= np.full(n, INF, dtype=np.int64) prev_v= np.full(n,-1, dtype=np.int64) prev_e= np.full(n,-1, dtype=np.int64) while quantity: dist.fill(INF) dist[s]= 0 que= [(0, s)] while que: total_cost, v= heappop(que) if dist[v] < total_cost: continue for i, (u, cap, cost, _)in enumerate(links[v]): new_cost= dist[v]+ potentials[v]- potentials[u]+ cost if cap >0 and new_cost < dist[u]: dist[u]= new_cost prev_v[u]= v prev_e[u]= i heappush(que, (new_cost, u)) # Cannot flow quantity if dist[t]== INF: return -1 potentials+= dist cur_flow= quantity v= t while v != s: cur_flow= min(cur_flow, links[prev_v[v]][prev_e[v]][1]) v= prev_v[v] quantity-= cur_flow res+= cur_flow* potentials[t] v= t while v != s: link= links[prev_v[v]][prev_e[v]] link[1]-= cur_flow links[v][link[3]][1]+= cur_flow v= prev_v[v] return res |
留意点
同じNumpy配列同士を演算するとエラー
AtCoderで使われて いる いた過去の Numba 0.48.0 では、同じNumPy配列同士を演算すると(?)エラーになる。(※詳細な条件はちゃんと調べてない)
0.53では修正されているのを確認している。
一方を別の名前の変数で定義してやると大丈夫になる。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | # エラー(配列 x の各値をそれぞれ a 乗するコード)@njitdef mod_pow(x, a): ret= np.ones_like(x) cur= x while a >0: if a &1: ret= ret* cur% MOD cur= cur* cur% MOD # ←エラー a >>= 1 return ret# おっけーdef mod_pow(x, a): ret= np.ones_like(x) cur= x while a >0: if a &1: ret= ret* cur% MOD cur_= cur cur= cur* cur_% MOD a >>= 1 return ret |
整数が0か0以外かの判定はちゃんと書く
Numba 0.57.0 で確認。
Pythonでは、整数が0か0以外かの判定に「if a:」「while a:」などと書いても解釈してくれるが、Numbaではbool値として解釈されてしまうことがある。
その場合、コンパイルされた関数中の は全てbool値となるので、正整数を渡しても強制的に 1 になるなど、おかしくなる。(詳細な条件は不明)
1 2 3 4 5 6 7 8 9 10 11 12 13 | @njitdef main(): def mod_pow(x, a): ret= 1 cur= x while a: # ← ここの記述からか、a は関数全体を通してbool値として扱われる if a &1: ret= ret* cur cur= cur* cur a >>= 1 return ret print(mod_pow(10,5)) # 10^1 として渡ってしまい、10 が返る |
「while a > 0:」など、明示的にint型であることが分かるような書き方をする必要がある。