けんちょん (drken) の競プロ精進記録

主に精進した記録や小ネタを書いていきます。各種アルゴリズムの詳しい説明などは Qiita (https://qiita.com/drken) の方に書いていきたいと思います。

よくやる二項係数 (nCk mod. p) の求め方 ~ 「逆元」の求め方も ~

1. 最も典型的な二項係数の求め方 (1 ≦ k ≦ n ≦ 107 程度)

競プロをしていると、nCk mod. p を計算する場面にしばしば出くわします。時と場合によって色んな方法が考えられますが、以下のものを頻繁に使用するイメージです。多くの AtCoder のトッププレイヤーたちも使用している形式でそれなりに高速です。5 年前のりんごさんのツイートが 1 つのきっかけとなって広まった印象があります:

使い方としては、最初に一度前処理として COMinit() をしておけば、あとは COM(n, k) 関数を呼べばよいです。

  • 前処理 COMinit(): O(n)
  • クエリ処理 COM(n, k): O(1)
#include <iostream>
using namespace std;

const int MAX = 510000;
const int MOD = 1000000007;

long long fac[MAX], finv[MAX], inv[MAX];

// テーブルを作る前処理
void COMinit() {
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for (int i = 2; i < MAX; i++){
        fac[i] = fac[i - 1] * i % MOD;
        inv[i] = MOD - inv[MOD%i] * (MOD / i) % MOD;
        finv[i] = finv[i - 1] * inv[i] % MOD;
    }
}

// 二項係数計算
long long COM(int n, int k){
    if (n < k) return 0;
    if (n < 0 || k < 0) return 0;
    return fac[n] * (finv[k] * finv[n - k] % MOD) % MOD;
}

int main() {
    // 前処理
    COMinit();

    // 計算例
    cout << COM(100000, 50000) << endl;
}

注意

上記の実装は十分高速ではありますが、inv 配列が確実に不要な場面では、

  • fac 配列を計算する
  • finv[n] を逆元計算によって計算しておく
  • finv[i-1] = finv[i] * i % MOD によって finv 配列を後ろから計算していく

とした方が僅かに速いようです (@CuriousFairy315 さんより)

使用可能場面

  • 1kn107
  • p素数 (上の実装ではさらに pn を仮定している)

使用原理

nCk=n!k!(nk)!=(n!)×(k!)1×((nk)!)1

であることを利用しています。COMinit() で、a! (fac[a]) と (a!)1 (finv[a]) のテーブルを予め作っています。これを作っておくことで、クエリ計算が掛け算のみになって高速になります。

fac[0], fac[1], ..., fac[n-1] の計算が O(n) でできることは難しくない感じです。いわゆる累積和ならぬ累積積をやっている感じです。一方、finv[0], finv[1], ..., finv[n-1] の計算も実は O(n) でできることは驚きです。finv を計算するために、mod. p における 1, 2, ..., n の逆元 inv[1], inv[2], ..., inv[n] を O(n) で求めます。そうすれば、inv の累積積をとることで finv も O(n) で計算できます。

p を素数としたとき、mod. p での逆元計算方法には大きく分けて

  • 拡張 Euclid の互除法
  • Fermat の小定理

とがあります。ともに O(logp) かかりますが、多くの場合、拡張 Euclid の互除法の方が高速に動作します。さて、愚直に 1, 2, ..., n の逆元を計算していては O(nlogp) かかってしまいます。ところが

  • i の逆元を、p % i (これは i より小さいことに注意) の逆元を利用して O(1) で求める

という魔法のようなテクニックがあります。そのテクニックを用いることで、mod. p における 1, 2, ..., n の逆元を O(n) で計算できます。そのことを理解するために、そもそも拡張 Euclid の互除法が何をしていたかを考えます。

拡張 Euclid の互除法による逆元計算

a1mod.p を計算するとはすなわち

ax+py=1

を満たす x を求めたいということになります。Euclid の互除法を適用します。すなわち、pa で割ってみます:

p=qa+r

これを代入すると

ax+(qa+r)y=1ry+a(x+qy)=1

になります。これによって (a,p) に関する問題が、それより数値の小さな (r,a) に関する問題に帰着できました。これを再帰的に解くのが拡張 Euclid の互除法です。具体的には (r,a) に関する小問題を解いて

rs+at=1

と解 (s,t)再帰的に得られたとすると、

y=s,x+qy=tx=tqs,y=s

という風に、元の問題の解を構成できます。下に a1 (mod. m) を求める実装を示します。なお注意点として、

  • 逆元を求める mod. m の m素数でなくても、am が互いに素であればよい
  • 下の拡張 Euclid の互除法自体は ab が互いに素でなくても適用できるが、逆元計算するときの am とは互いに素である必要がある

となっています。

#include <iostream>
using namespace std;

// ax + by = gcd(a, b) となるような (x, y) を求める
// 多くの場合 a と b は互いに素として ax + by = 1 となる (x, y) を求める
long long extGCD(long long a, long long b, long long &x, long long &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    long long d = extGCD(b, a%b, y, x); // 再帰的に解く
    y -= a / b * x;
    return d;
}

// 負の数にも対応した mod (a = -11 とかでも OK) 
inline long long mod(long long a, long long m) {
    return (a % m + m) % m;
}

// 逆元計算 (ここでは a と m が互いに素であることが必要)
long long modinv(long long a, long long m) {
    long long x, y;
    extGCD(a, m, x, y);
    return mod(x, m); // 気持ち的には x % m だが、x が負かもしれないので
}

int main() {
    // 計算例
    cout << modinv(3, 7) << endl;
}

改めて、inv[1], inv[2], ..., inv[n] を高速に計算する方法

拡張 Euclid の互除法におけるアイディアを少し変形して実現します。a1 を求めるために同じように

ax+py=1

を満たす x を求めて行きます。pa で割って

p=qa+r

とするのですが、ここから上手いことやります。まず、ax+py=1 の両辺を q 倍して変形していくと、

qax+qpy=q(pr)x+qpy=qrx+p(xqy)=q

となります。拡張 Euclid の互除法では (a,p) に関する問題を (r,a) に関する問題に帰着していたのに対し、今回は (r,p) に関する問題に帰着しています。p を残していることがミソですね。さて、z=xqy とおいて、rx+pz=q を満たす x(z) を求めることができれば万事解決ということになります。

rx+pz=1

を満たす (x,z)(s,t) とすると、rs+pt=1 であり、これを両辺 q 倍することで

r(sq)+p(tq)=q

となるので、x=sq,z=tqrx+pz=q を満たします。

  • s=r1=(p % a)1
  • q=(p/a)

であることに注意すると、

a1(p % a)1×(p/a) (mod. p)

であることが導かれました。実装上は

inv[a] = MOD - inv[MOD % a] * (MOD / i) % MOD;

という風にします。

逆元漸化式のもう 1 つの導出方法

上では拡張 Euclid の互除法を意識した導出を示しましたが、もっと直接的に導くこともできます。takapt さんの記事にもある導出方法です。ma で割ると

m=(m/a)a+(m%a)

で、これを変形することによって導出することができます。この両辺の mod. m をとると、(mod. m) は省略して、

(m/a)a+(m%a)=0

(m/a)+(m%a)a1=0

a1=(m%a)1×(m/a)

という風に簡潔に導出することができます。

2. n がさらに巨大で固定値なとき (1 ≦ n ≦ 109, 1 ≦ k ≦ 107 程度)

n が巨大なときは先程の手が使えません。しかし k が小さければ望みがあります。

nCk=n1×n12×...×nk+1k

を利用して O(k) で計算することができます。さらに n が固定値の場合も多く、そんなときは配列テーブル

com[ k ] = nCk

O(k) で前計算しておくことが有効です。

3. n も k も小さいとき ( 1≦ k ≦ n ≦ 2000 程度、mod. p が素数でなくてもよい)

動的計画法によって nCk のテーブルを生成する方法が有力で、mod の p が素数でなくてもよいのが魅力的です。

const long long MOD = 1000000007;
const int MAX_C = 1000;
long long Com[MAX_C][MAX_C];

void calc_com() {
    memset(Com, 0, sizeof(Com));
    Com[0][0] = 1;
    for (int i = 1; i < MAX_C; ++i) {
        Com[i][0] = 1;
        for (int j = 1; j < MAX_C; ++j) {
            Com[i][j] = (Com[i-1][j-1] + Com[i-1][j]) % MOD;
        }
    }
}

4. さらに

n と k がそれほど小さくなく mod が素数でない場合など、いくらでもイヤな場合は考えられますが、そこまで来たら uwi さんの記事を読めば大抵のことは解決します:

5. 二項係数を用いる問題たち

二項係数を用いる問題たちです。