ログイン新規登録

Qiitaにログインしてダークテーマを使ってみませんか?🌙

ログインするとOSの設定にあわせたテーマカラーを使用できます!

13

【internal_math編①】AtCoder Library 解読 〜Pythonでの実装まで〜

最終更新日 投稿日 2020年10月04日

0. はじめに

2020年9月7日にAtCoder公式のアルゴリズム集 AtCoder Library (ACL)が公開されました。
私はACLに収録されているアルゴリズムのほとんどが初見だったのでいい機会だと思い、アルゴリズムの勉強からPythonでの実装までを行いました。

この記事では internal_math をみていきます。

internal_mathはACLの内部で使われる数論的アルゴリズムの詰め合わせで内容は以下の通りです。

名称 概要
safe_mod 整数 x の正整数 m による剰余(x)。ただし 0x を満たす。
barrett 高速な剰余演算。
pow_mod xn(modm) の計算。
is_prime 高速な素数判定。
inv_gcd 整数 a と正整数 b の最大公約数 g および xag(modb) となる x の計算。ただし 0x<bg を満たす。
primitive_root m を法とする原始根。
floor_sum_unsigned 0 以上の整数 n,a,b と自然数 m について i=0n1ai+bm の計算。

本記事ではこれらの内、

  • safe_mod
  • pow_mod
  • inv_gcd
  • barrett

を扱います。なお、constexpr(定数式)自体については触れません。
本記事で扱わない

  • is_prime
  • primitive_root

については以下の記事で扱っています。よろしければそちらもご覧ください。

【internal_math編②】AtCoder Library 解読 〜Pythonでの実装まで〜

なお、floor_sum_unsigned は math 内の floor_sum の補助的な役割なので説明は

【math編】AtCoder Library 解読 〜Pythonでの実装まで〜

で行なっています。

対象としている読者

  • ACLの内部で使われている数学系アルゴリズムを知りたい方。
  • ACLのコードを見てみたけど何をしているのかわからない方。
  • C++はわからないのでPythonで読み進めたい方。

参考にしたもの

C++のドキュメントです。

@drkenさんによる、競技プログラミングで問われるあまりの求め方がまとめられた記事です。非常にわかりやすいです。

ユークリッドの互除法を数式で記述する参考にしました。

Barrett reduction の参考記事です。

1. safe_mod

整数 x の正整数 m による剰余 x を考えます。

1.1. C++での剰余演算

C++では以下のようになります。

#include <bits/stdc++.h>
using namespace std;
int main(){
    cout << " 7 % 3 = " << 7 % 3 << endl;  // 7 % 3 = 1
    cout << "-7 % 3 = " << -7 % 3 << endl;  // -7 % 3 = -1
    
    return 0;
}

注目するのは2つ目で、 x が負のとき剰余も負となります。これは

  • 商は0に向かって丸め込まれる
  • (x/m)m+x を満たす

によるものです。(参考)

1.2. 実装

safe_mod は剰余 x の値を 0x に収める関数です。これは剰余演算の結果が負の場合に m を加算することで達成できます。
Pythonでの実装は以下のようになります。

def safe_mod(x, m):
    x %= m
    if x < 0: x += m
    return x

なお、Pythonでは正整数 m による剰余は0x となることが保証されているのでこの実装は不要です。(これはPythonでの整数除算 "//" が負の無限大に向かって丸め込まれるからです。)

# safe_modによる剰余演算
print(f'safe_mod(7, 3) = {safe_mod(7, 3)}')  # safe_mod(7, 3) = 1
print(f'safe_mod(-7, 3) = {safe_mod(-7, 3)}')  # safe_mod(-7, 3) = 2

# 算術演算子%による剰余演算
print(f' 7 % 3 =  {7 % 3}')  # 7 % 3 = 1
print(f'-7 % 3 =  {-7 % 3}')  # -7 % 3 = 2

2. pow_mod

整数 x および自然数 n,m について xn(modm) の計算を考えます。

2.1. 素朴な方法

指数 n が自然数の場合 xnxn 回かけたものですので、その通りに実装すると以下のようになります。

def naive_pow_mod(x, n, m):
    r = 1  # 結果を入れる変数
    for _ in range(n):
        r *= x
    r %= m  # mで割ったあまりを求める
    return r


x = 3
n = 4
m = 5
print(naive_pow_mod(x, n, m))  # 1

しかし、n109 程度の場合はどうでしょうか。上記の方法では2つの問題点が考えられます。

  1. 109 回程度の乗算が必要になる
  2. 計算の過程で値が非常に大きくなる

これらによって非常に長い計算時間が必要になります。そしてこれらの問題点を解決することで高速に xn(modm) を計算するpow_modが実装できます。

2.2. 問題点1の解決策: 繰り返し二乗法

繰り返し二乗法(バイナリ法とも呼ばれます)はその名の通り二乗を繰り返します。これによって大きな指数の計算を少ない計算回数で行えます。

  • x を二乗して x2 を得る。
  • x2 を二乗して x3を計算することなく x4 を得る。
  • x4 を二乗して x5,x6,x7を計算することなく x8 を得る。

という具合です。具体例として n=50 の場合を見てみましょう。いま、指数が2のべき乗の場合(x,x2,x4,)の結果が得られているので x50 をこれらで表すことを考えます。

50=32+16+2

なので

x50=x32x16x2

とかけます。よって、 x50

  • x32 を求めるまでに5回
  • それらを掛け合わせる3回(結果の値として1を用意した場合)

合計8回の乗算で求められます。
これらの手続きを機械的に行うためには二進数を活用するのが有効です。二進数の定義からいって当たり前ですが、整数 n を構成する2のべき乗(上の例では32, 16, 2)は n を二進数表記したときに'1'であるビットに対応しています。そこで n を下位ビットから見ていき、'1'の場合に対応する x(2);; をかけることで求める値が得られます。
Pythonでの実装は以下のようになります。

# 繰り返し二乗法によるべき乗計算
def binary_pow_mod(x, n, m):
    r = 1  # 結果を入れる変数
    y = x  # x^(2のべき乗) を入れる変数
    while n:
        if n & 1: r = r * y  #最下位ビットが1なら乗算する
        y = y * y
        n >>= 1  #右シフトで次のビットを見る
    r %= m
    return r

2.3. 問題点2の解決策: mod演算の性質

コンピューターは(人間から見て)非常に高速に計算でき、またPythonではメモリが許す限り大きな整数を扱うことができますが、それでもやはり大きな桁の計算は時間がかかります。例えば 31000000000 を繰り返し二乗法を用いて計算する場合、最終的に

3536870912×3463129088

を計算することになります。もし欲しいものが xn を自然数 m で割ったあまりならば、mod演算の性質を利用してこのような大きな数字の演算を回避することができます。使う性質は次のものです。


乗算は(最後にmodを取りさえすれば)いつ何度でもmodをとって良い


xm で割ったあまりは常に 0x の範囲に収まるので乗算をするたびにmodをとることで常に m 未満の値同士の演算にすることができます。

2.4. 実装

pow_modは繰り返し二乗法にmod演算の性質を使ったものです。これは先ほど実装したbinary_pow_modに少し手を加えることで実装できます。

def pow_mod(x, n, m):
    if m == 1: return 0  # 1で割った余りは常に0
    r = 1
    y = x % m  # mで割ったあまりにする
    while n:
        if n & 1: r = r * y % m  # mで割ったあまりにする
        y = y * y % m  # mで割ったあまりにする
        n >>= 1  
    return r

なお、Pythonでは組み込み関数のpow()がpow_modに相当するのでこの実装は不要です。

# 素朴な方法での実装
print(naive_pow_mod(3, 4, 5))  # 1
#print(naive_pow_mod(13, 1000000, 1000000007))  # 終わりません


# 繰り返し二乗法を用いた実装
print(binary_pow_mod(13, 1000000, 1000000007))  # 735092405 このくらいならなんとか計算できます
#print(binary_pow_mod(13, 1000000000, 1000000007))  # 終わりません


# 繰り返し二乗法 + modの性質を用いた実装(ACLのpow_modに相当)
print(pow_mod(13, 1000000000, 1000000007))  # 94858115


# Pythonの組み込み関数powを用いた計算
print(pow(13, 1000000000, 1000000007))  # 94858115

3. inv_gcd

整数 a と正整数 b に対し、

  • 最大公約数 gcd(a,b)
  • xagcd(a,b)(modb) となる x

を計算します。ただし x0x<bgcd(a,b) を満たします。

3.1. ユークリッドの互除法

まずは ab の最大公約数から見ていきましょう。
最大公約数を計算するアルゴリズムとしてユークリッドの互除法というものがあります。
Pythonでの実装は以下のようになります。

def gcd(a, b):
    if b == 0: return a
    return gcd(b, a % b)

以降では a>b とします。a<b であったとしても 0a より再帰的に呼ばれる gcd(b,a は必ず第一引数の方が大きくなるので問題ありません。

最大公約数がユークリッドの互除法によって計算できるのは以下の2つによるものです。

  1. gcd(a,b)=gcd(b,a
  2. a0 に対し gcd(a,0)=a

1.の証明は@drkenさんの記事にありますのでそちらをご覧ください。
2.は a が 0 の約数であることから明らかです。

1.の主張は強力です。いま a>b かつ b>a なので1.によると


ab の最大公約数を求める問題はより小さな数の最大公約数を求める問題に変えることができる


となります。また、a なので1.を繰り返し用いることでいつか必ず gcd(g,0) という形になります。そして 2.よりこの g こそが求める ab の最大公約数です。

3.2. ユークリッドの互除法を数式で記述する

ユークリッドの互除法を数式で表します。
a=r0,b=r1 とし、rkrk+1 で割った商を qk と書くとき

r0=q0r1+r2r1=q1r2+r3rn1=qn1rn+0

となり、このようにして得られる rnab の最大公約数でした。
上の式を行列を用いて表現すると

(r0r1)=(q0110)(r1r2)(r1r2)=(q1110)(r2r3)(rn1rn)=(qn1110)(rn0)

となります。これらをまとめて書くと

(r0r1)=(q0110)(q1110)(qn1110)(rn0)()

が得られます。
いま、i=0,1,,n1 に対し

Qi=(qi110)

とおくと行列式は

det(Qi)=qi011=1

なので、逆行列 Qi1 が存在し

Qi1=(011qi)=(011qi)

です。したがって式()は

(gcd(a,b)0)=(011qn1)(011qn2)(011q0)(ab)()

となります。

3.3. 拡張ユークリッドの互除法

さて、inv_gcdで得られる2つ目のもの

  • xagcd(a,b)(modb) となる x

を見ていきましょう。これは整数 y を用いて

ax+by=gcd(a,b)

と書くこともできます。よって、この式を満たす x を求めればよいことになります。
ところで、前節で得られた式()において

(xyuv)=(011qn1)(011qn2)(011q0)

とおくと

(gcd(a,b)0)=(xyuv)(ab)

となります。つまり、この x,y

ax+by=gcd(a,b)

を満たします。したがって、式()を用いてgcd(a,b)を計算することで、その過程において xagcd(a,b)(modb) となる x が得られることがわかります。このようにして整数 (x,y) を求めるアルゴリズムを拡張ユークリッドの互除法と言います。

3.4. inv_gcd実装の準備①

ユークリッドの互除法は r0=a,r1=b とし、手続き

(ri+1ri+2)=(011qi)(riri+1)

ri+20 になるまで繰り返すことでした。そして

(xiyiuivi)=(011qi)(011qi1)(011q0)

を同時に計算することで ax+by=gcd(a,b) を満たす(x,y)を求めるのが拡張ユークリッドの互除法です。

では実際に計算する過程を見ていきましょう。
a,b が与えられたとき、まず a0 かどうかを確認します。もし 0 ならば、以降の手続きをすることなく

gcd(a,b)=bx=0

であることがわかります。
a,b が与えられたときの各変数の初期状態は以下のようになります。

r0=a,r1=b(x1y1u1v1)=(1001)

(x,y,u,v) の初期状態が単位行列なのは i=0 においてこれらの変数が

(x0y0u0v0)=(011q0)

を満たすためです。
次に各変数の従う漸化式を見ていきます。
まず、ri+1,ri から qi が求まります。

qi=riri+1

次にこれを用いて他の変数の遷移がわかります。

ri+2=riqiri+1(xiyiuivi)=(011qi)(xi1yi1ui1vi1)

終了条件は ri+2=0 です。このとき

ri+1=gcd(a,b)axi+byi=gcd(a,b)

となります。

3.5. inv_gcd実装の準備②

ax+by=gcd(a,b) は不定方程式なので解が無数に存在します。ここでは拡張ユークリッドの互除法によって得られた解 x|x|<bgcd(a,b) を満たすことを示します。
そのためにまず、以下を示します。


i0 に対し、

ri+1|ui|+ri+2|xi|b

が成り立つ。


用いるのはri,xi,ui の漸化式

ri+2=riqiri+1xi=ui1ui=xi1qiui1

と絶対値の性質

|xy||x|+|y|

です。
数学的帰納法により示します。
i=0 のとき、

r1|x0|+r2|u0|=b|1|+r2|0|=bb

で満たします。
i=k のとき、

rk+1|uk|+rk+2|xk|b

を満たすと仮定すると i=k+1 のとき、

rk+2|uk+1|+rk+3|xk+1|=rk+2|xkqk+1uk|+(rk+1qk+1rk+2)|uk|rk+2|xk|+qk+1rk+2|uk|+rk+1|uk|qk+1rk+2|uk|=rk+1|uk|+rk+2|xk|b

となり満たします。
以上より、i0 に対し、

ri+1|ui|+ri+2|xi|b

が成り立つことが示されました。

この結果を使って |x| を評価しましょう。いま、rn+2=0 であったとします。すなわち、

rn+1=gcd(a,b)xnagcd(a,b)(modb)

となります。先ほど示した不等式において i=n1 の場合を考えます。すると、

()=rn|un1|+rn+1|xn1|rn|un1|>rn+1|un1|=gcd(a,b)|xn|

であり、aより gcd(a,b)b なので、拡張ユークリッドの互除法によって得られた xn

|xn|<bgcd(a,b)

を満たすことが示されました。
また、xn<0 であった場合

x=xn+bgcd(a,b)

とすることで 0x<bgcd(a,b) となる x を得ることができます。この x

xa=(xn+bgcd(a,b))a=xna+lcm(a,b)xna(modb)

より確かに求める解になっています。ここで、lcm(a,b)(a,b) の最小公倍数です。

3.6. inv_gcd実装の準備③

inv_gcd実装の準備①では各変数に添字をつけて数列のようにそれぞれの遷移を追ってきましたが、実装上は古い値を保持する必要はないのでよりコンパクトに書くことができます。
ここからはACLで実際に使われている変数名に沿っていきます。

まずは必要な変数を確認しましょう。
ri は三項間漸化式なので変数は二つ必要になります。これらを s,t とします。qi は一つあればよいのでこれを u とします。(xi,yi,ui,vi) ですが、漸化式を見ると(xi,ui)(yi,vi) は独立であることがわかります。いま、欲しいのは x なので (yi,vi) は保持する必要がありません。そこで、(xi,ui)(m0,m1) とします。

準備①でも述べたとおり、a,b が与えられたときまず a を確認します。これが 0 ならば以降の手続きをすることなく

gcd(a,b)=bx=0

で終わりです。よって、以降では a の場合を見ていきます。
初期状態は

s=r1=bt=r2=r0r0r1r1=a%bm0=x0=u1=0m1=u0=x1r0r1u1=1

です。ここから、t=0 となるまで下図の遷移を繰り返します。

inv_gcd_1.png

3.7. 実装

前節で述べた通りに実装していきます。なお、ACLにおいて safe_mod を用いている部分はPythonにおいて同等の機能である算術演算子 % で代用します。

def inv_gcd(a, b):
    a %= b
    if a == 0: return b, 0
    # 初期状態
    s, t = b, a
    m0, m1 = 0, 1
    while t:
        # 遷移の準備
        u = s // t

        # 遷移
        s -= t * u
        m0 -= m1 * u

        # swap
        s, t = t, s
        m0, m1 = m1, m0

    if m0 < 0: m0 += b // s
    return s, m0


print(inv_gcd(3, 5))  # (1, 2)
print(inv_gcd(20, 15))  # (5, 1)

4. barrett

「ある自然数 m が与えられるので、m で割ったあまりを答えよ」という問題ではオーバーフロー対策や、多倍長整数であったとしても巨大な数によって計算速度が低下することを避けるため、何かの演算をするたびに m で割ったあまりをとるということをよくします。あまりを求めるためには割り算をする必要がありますが、割り算は四則演算の中でコンピューターが苦手としている演算であり、他の演算よりも時間がかかってしまいます。Barrett reductionは「ある決まった自然数(定数) m で割ったあまりをとる演算」を高速化するアルゴリズムです。
ACLにおいては 0a<m,0b<m となる整数 a,b に対して

(ab)%m

を高速化する目的で使われています。
以降では z:=ab;(0z<m2) とおき、z を考えます。

4.1. アイデア

いくら割り算が苦手と言っても、あまりを求める際には避けては通れません。なぜならあまりは

z%m=zz÷mm

と表されるからです。また、一般的な数による割り算は苦手でも2のべき乗による割り算は得意です。2進数で動くコンピューターにとって、「2k で割る」ということは「右に k シフトする」だけでよいからです。

よって、あまりを求める演算を高速化するために、苦手な演算(一般的な自然数による割り算)を得意な演算(足し算、引き算、掛け算、シフト演算)に置き換えることを考えます。
m で割る」は「1mをかける」と等価です。いま、自然数 k,n を用いて十分良い精度で

1mn2k

と近似できたとします。すると、

z%m=zz1mm=zzn2km=z{(zn)>>k}m

となり、得意な演算(引き算、掛け算、シフト演算)だけであまりを表すことができました。

4.2. k, n の決め方

では、自然数 k,n はどのように決めれば良いでしょうか。もし m が 2 のべき乗であったなら、

1m=n2k

を満たす自然数 k,n が存在します。よって以降では、m が 2 のべき乗でない場合を考えます。
まず、k が満たすべき条件を見ていきます。いま、2k<m となる k を選んだとすると、n=1 がもっとも良い近似となりますが、これはほとんどの場合良い近似ではありません。よって、k

2k>m

となるように選ぶ必要があります。
k が決まれば n

n2km

となるように選ぶだけです。よって

n=2kmor2km

となります。ACLでは天井関数を採用していますが、一般的には床関数を選ぶことが多いようです。

さて、k の下限はわかりましたが、具体的な値はどう決めれば良いでしょうか。いま、近似の指標として

e=n2k1m

を導入します。具体例として競技プログラミングでよく問われる m=1000000007;;(109+7) の場合を見てみましょう。このとき、k の下限は

kmin=log21000000007=30

です。そこで、k30k について e を計算してみると下図のようになります。

barrett_1.png

ek に対して単調非増加ですので、k は大きいほど良いことがわかりました。ただし、k が大きくなるほど扱う数が大きくなり、その分計算に時間がかかるのでとにかく大きくすれば良いというわけではなさそうです。
ACLでは k=64 としています。これによって n

n=264m

と決まります。


補足1
n は 正確には

n={0(m=1)264m(m2)

となっています。
ACLにおいて入力 a,b は既に m で割ったあまりとなっていることが要求されます。m=1 のとき、a=b=0 より

z%m=z{(zn)>>k}m=001=0

となるので問題ありません。よって以降では m2 とします。

補足2
割り算が入っていると思われるかもしれませんが、いま m は事前に与えられる定数なので n もまた事前計算によって定数となります。

4.3. なぜ k=64 ?

では、なぜ k=64 なのでしょうか。
ひとつの理由として unsigned long long (符号なし8バイト整数)で扱える最大値が 2641 であることが挙げられます。
そしてもうひとつの理由としては、k=64 であれば4バイト整数の入力 a,b と4バイト整数の法 m に対して (ab) を正しく計算できる、というものがあります。
これを示すためには以下のことを示せば良いでしょう。


2m<231 なる整数 m と 0a,b<m なる整数 a,b に対し積 z:=ab が整数 q,r を用いて

z=qm+r(0r<m)

と表されたとき、次の関係式が成立する。

q{(zn)>>64}<q+2

ここで、

n=264m

であり、>>は右シフト演算である。


これが示された場合、

{(zn)>>64}=qorq+1

となります。すなわち z を正確に計算できたか、もしくは正確な値より m だけ小さく計算してしまったかのどちらかになります。よって、もし得られた結果が負の値であった場合には m を加算することで正しい結果が得られます。

それでは証明していきます。

まずは下限から見ていきます。いま、

n2641m

であるから、

zn264zm=q+rm=q+rm=q

したがって、

{(zn)>>64}q

となります。

つづいて、上限を見ていきます。nm0l<m なる整数 l を用いて

nm=264mm=264+l

とかけることを利用すると

zn=(qm+r)n=qnm+nr=264q+(ql+nr)

が得られます。ここで、z<m2 より q<m であるから

ql+nr<m2+nm<m2+264+m=264+m(m+1)<2264

となるので

zn<264q+2264=264(q+2)

したがって、

{(zn)>>64}<q+2

が成り立ちます。

以上より

q{(zn)>>64}<q+2

が示されました。

4.4. 実装

では実装していきます。まずクラスBarrettを作成しコンストラクタを実装します。法 m を確認するためのメソッドも用意しておきます。

class Barrett:
    # @param 1 <= m < 2^31
    def __init__(self, m):
        self._m = m
        self.im = -(-(1<<64) // m) if m > 1 else 0
    
    # mを返すメソッド
    def umod(self):
        return self._m

変数imがこれまでの n に相当します。ACLでは固定長整数の性質を活かした書き方をしていますが、Pythonは多倍長整数なのでif文で場合分けをして処理します。
このように m による割り算が必要な部分を事前計算し定数として保持しておくことで高速化します。

ではメソッド "mul" を実装します。これが a,b に対し (ab) を返すメソッドです。

class Barrett:
    # __init__()
    # umod()

    def mul(self, a, b):
        assert 0 <= a < self._m
        assert 0 <= b < self._m
        z = a * b
        v = z - ((z * self.im)>>64) * self._m
        if v < 0: v += self._m
        return v
    

m = 1000000007
bt = Barrett(m)
a = 12345678
b = 87654321
print(bt.mul(a, b))  # 14799574
print((a * b) % m)  # 14799574

4.5. 実用性は...

察している方もいるかと思いますが、 Pythonでは組み込みの算術演算子 % を素直に使った方が速いです。

巨大な数になれば効果が出てくるかもしれませんが、競技プログラミングで扱うような数では(Pythonでは)必要なさそうです。

5. おわりに

今回はACLの内部で使われるアルゴリズムを見てきました。数式を追うのが大変でしたが、理解が深まりました。

internal_mathのなかで今回触れなかったものについてはinternal_math編②で書いていますので、よろしければそちらもご覧ください。

説明の間違いやバグ、アドバイス等ありましたらお知らせください。

新規登録して、もっと便利にQiitaを使ってみよう

  1. あなたにマッチした記事をお届けします
  2. 便利な情報をあとで効率的に読み返せます
  3. ダークテーマを利用できます
ログインすると使える機能について

コメント

この記事にコメントはありません。

いいね以上の気持ちはコメントで

13

新規登録して、Qiitaをもっと便利に使ってみませんか

この機能を利用するにはログインする必要があります。ログインするとさらに下記の機能が使えます。

  1. ユーザーやタグのフォロー機能であなたにマッチした記事をお届け
  2. ストック機能で便利な情報を後から効率的に読み返せる

ソーシャルアカウントでログイン・新規登録

メールアドレスでログイン・新規登録