人間だったら考えて

考えて考えて人間だったら考えて

「ベイズ推論による機械学習入門」を読んだので実験してみた(その2)

この記事は何?

szdr.hatenablog.com
に引き続き、「ベイズ推論による機械学習入門」で紹介されているアルゴリズムを実装・実験していきます。

www.kspub.co.jp


今回は4章の「ポアソン混合モデルにおける推論」で紹介されている、ポアソン混合モデルのためのギブスサンプリングポアソン混合モデルのための変分推論を実装・実験しました。

やりたいこと

以下の図で表されるような二峰性の1次元データを考えます。
f:id:sz_dr:20171209173016p:plain
このデータは、Poi(x|λ1=15)から300個サンプリング、Poi(x|λ2=30)から200個サンプリングした合計500個のサンプルです。

この1次元データをK=2つのクラスタに割り当てる問題をこの記事では考えます。

混合モデルのデータ生成過程を以下のように考えます。

  1. 2つのクラスタの混合比率π=(π1,π2)Tが事前分布p(π)=Dir(π|α)から生成される
  2. それぞれのクラスタkに対する観測モデルのパラメータλkが事前分布p(λk)=Gam(λk|a,b)から生成される
  3. n=1,...,Nに関して、xnに対応するクラスタの割り当てsnp(sn|π)=Cat(sn|π)から生成される
  4. n=1,...,Nに関して、snによって選択されたk番目の確率分布p(xn|λk)=Poi(xn|λk)からデータxnが生成される

グラフィカルモデルで表すと以下の図のようになります。
f:id:sz_dr:20171209182625p:plain:w300

事後分布p(S,λ,π|X)が求まると嬉しいのですが、事後分布の式を解析的に求めるのは難しいのでギブスサンプリングや変分推論といった近似推論が用いられます。

近似推論のアルゴリズムの導出は本を参考にしていただくとして、この記事ではどのように実装するかを紹介していきます。

注意:本もしくは著者のブログ記事(以下参照)を読んでいないと、なぜそんなアルゴリズムが出てくるのかなど意味不明だと思うので、ぜひそちらを参照してください。
machine-learning.hatenablog.com
machine-learning.hatenablog.com


ポアソン混合モデルのためのギブスサンプリング

アルゴリズムは以下のようになります。 (本のP133)

λ,πを初期化
FOR i = 1, ..., MAXITER
 FOR n = 1, ..., N
  ηn,kexp(xnlnλkλk+lnπk) s.t.k=1Kηn,k=1を求める
  snCat(sn|ηn)としてサンプリング
 END FOR
 FOR k = 1, ..., K
  λkGam(λk|a^k=n=1Nsn,kxn+a,b^k=n=1Nsn,k+b)としてサンプリング
 END FOR
 πDir(π|{α^k=n=1Nsn,k+αk}k=1K)としてサンプリング
END FOR
このアルゴリズムを実行すると、sn,λk,πのサンプリング結果が得られます。
サンプリング結果を見ることで、各データのクラスタ割り当て確率・各クラスタポアソン分布のパラメータ・クラスタの混合比率を推論できます。

それでは、このアルゴリズムを実装してみます。

# gibbs sampling (4.2)
def mixture_poisson_gibbs_sampling(X, K, max_iter):
    # X: shape -> (N, 1)
    lmd = np.zeros((K, 1)) + 1  # (1, 1, ..., 1)で初期化
    pi = np.zeros((K, 1)) + 1 / K  # (1/K, 1/K, ..., 1/K)で初期化
    a = 1  # ガンマ分布ハイパーパラメータ
    b = 1  # ガンマ分布ハイパーパラメータ
    alpha = np.zeros((K, 1)) + 1  # ディリクレ分布ハイパーパラメータ
    N = X.shape[0]
    
    sampled_lmd = []
    sampled_pi = []
    sampled_S = []
    for i in range(max_iter):
        # s_nをサンプル
        tmp = X.dot(np.log(lmd).reshape(1, -1)) - lmd.reshape(1, -1) + np.log(pi).reshape(1, -1)
        log_Z = - log_sum_exp(tmp)
        eta = np.exp(tmp + log_Z)
        S = np.zeros((N, K))
        for n in range(N):
            S[n] = multinomial.rvs(n=1, p=eta[n], size=1)
        sampled_S.append(S.copy())
            
        # lmd_kをサンプル
        hat_a = X.T.dot(S).reshape(-1, 1) + a
        hat_b = np.sum(S, axis=0).reshape(-1, 1) + b
        for k in range(K):
            lmd[k] = gamma.rvs(a=hat_a[k], scale=1/hat_b[k])
        sampled_lmd.append(lmd.copy())
        
        # piをサンプル
        hat_alpha = np.sum(S, axis=0).reshape(-1, 1) + alpha
        pi = dirichlet.rvs(hat_alpha.reshape(-1), size=1).reshape(-1, 1)
        sampled_pi.append(pi.copy())
    
    return np.array(sampled_lmd).reshape(-1, K), np.array(sampled_pi).reshape(-1, K), np.array(sampled_S).reshape(-1, N, K)

ηn,kを求める部分で和が1になるように正規化をします。
この部分で何も考えずに指数計算するとオーバーフローしてしまうので、logsumexpテクを使います。logsumexpに関しては混合ガウス分布とlogsumexp - Qiitaなどを参考にしてください。
今回はlogsumexpを以下のように実装してみました。

def log_sum_exp(X):
    # \log(\sum_{i=1}^{N}\exp(x_i))
    max_x = np.max(X, axis=1).reshape(-1, 1)
    return np.log(np.sum(np.exp(X - max_x), axis=1).reshape(-1, 1)) + max_x

100回サンプリング(max_iter=100)を行いました。まずはλ1,λ2のサンプリング結果を下図に示します。
f:id:sz_dr:20171210005007p:plain
λ1のサンプリング結果の平均値は14.852、λ2は29.728でした、結構良い精度が出ています。

次にπのサンプリング結果を下図に示します、π1+π2=1なので、π1の結果を示します。
f:id:sz_dr:20171210005817p:plain
π1のサンプリング結果の平均値は0.590でした。元データはクラスタ1から300個・クラスタ2から200個サンプリングしているので、混合比率は300 / (300 + 200) = 0.6となり、こちらも結構良い精度が出ています。

それでは各データのクラスタ割り当て確率を見てみます。
f:id:sz_dr:20171210010630p:plain
ヒストグラムの各ビン毎に、ビンに含まれるデータのクラスタ割り当て確率の平均値を求め、その値によって赤〜青で塗り分けてみました。
xが小さいほど赤色っぽく、大きいほど青色っぽくなってみます。

k-meansのようなクラスタリングアルゴリズムと違って、各データに対するクラスタ割り当て確率が求まります。
そのため、どれくらい自信を持って各データがそのクラスタに属するかということを評価できます、楽しいですね。

ポアソン混合モデルのための変分推論

アルゴリズムは以下のようになります。(本のP137)

q(λ),q(π)を初期化
FOR i = 1, ..., MAXITER
 FOR n = 1, ..., N
  ηn,kexp(xnlnλkλk+lnπk) s.t.k=1Kηn,k=1を求める
  q(sn)=Cat(sn|ηn)を更新
 END FOR
 FOR k = 1, ..., K
  q(λk)=Gam(λk|a^k=n=1Nsn,kxn+a,b^k=n=1Nsn,k+b)を更新
 END FOR
 q(π)=Dir(π|{α^k=n=1Nsn,k+αk}k=1K)を更新
END FOR

期待値計算がアルゴリズム中に含まれています、これらは以下の式で与えられます。
λk=a^kb^klnλk=ψ(a^k)lnb^klnπk=ψ(α^k)ψ(i=1Kα^i)
ここで、ψ()はディガンマ関数です。

それでは、このアルゴリズムを実装してみます。

# variational inference (4.3)
def mixture_poisson_variational_inference(X, K, max_iter):
    init_a = np.ones((K, 1))
    init_b = np.ones((K, 1))
    init_alpha = np.random.rand(K, 1)
    
    a = init_a.copy()
    b = init_b.copy()
    alpha = init_alpha.copy()
    
    for i in range(max_iter):
        # q(s_n)を更新
        ln_lmd_mean = digamma(a) - np.log(b)
        lmd_mean = a / b
        ln_pi_mean = digamma(alpha) - digamma(np.sum(alpha))
        tmp = X.dot(ln_lmd_mean.reshape(1, -1)) - lmd_mean.reshape(1, -1) + ln_pi_mean.reshape(1, -1)
        log_Z = - log_sum_exp(tmp)
        eta = np.exp(tmp + log_Z)
        
        # q(lmd_k)を更新
        a = X.T.dot(eta).reshape(-1, 1) + init_a
        b = np.sum(eta, axis=0).reshape(-1, 1) + init_b
        
        # q(pi)を更新
        alpha = np.sum(eta, axis=0).reshape(-1, 1) + init_alpha
    
    return a, b, eta, alpha    

100回回してみました(max_iter=100)。
まずは、得られたa,bを用いてq(λk)=Gamma(λk|ak,bk)を図示します。
f:id:sz_dr:20171210014516p:plain
クラスタ1に対応するポアソン分布のパラメータλ1は15くらいで、クラスタ2に対応するポアソン分布のパラメータλ2は29くらいでピークが立っています。

次に、得られたαを用いてq(π)=Dir(π|α^)を図示します。
f:id:sz_dr:20171210015037p:plain
混合比率のピークは0.59くらいで立っていますが、裾がやや広めに思われます。ただ「混合比率は0.59です!」って言うのではなくて、「多分0.59くらいなんじゃないですかねー」と言えるのが面白いです。

最後に、得られたηを用いて各データのクラスタ割り当て確率を見てみます。q(sn)=Cat(sn|ηn)であることから、ηnクラスタ割り当て確率として解釈できます。
f:id:sz_dr:20171210021219p:plain
図示方法はギブスサンプリングでやった時と同じです。ギブスサンプリングの時と同じような結果が得られていますね。

その他

  • 実際に実装してみると、ギブスサンプリングと変分推論の類似点が分かりました
  • 実装してみてようやくアルゴリズムの気持ちが分かりました、手動かすの大事ですね。。。
  • クラスタリング可視化の図を描くのに結構手間取りました、実装を載せておきます(ちょっと汚いですが。。。)
def plot_clustering(prob_mat, bin_num):
    bins = np.linspace(np.min(X), np.max(X), num=bin_num)
    X_inds = np.digitize(X, bins)
    norms = np.zeros(bin_num + 1)
    cnts = np.zeros(bin_num + 1)
    for x_ind, prob_n in zip(X_inds[:, 0], prob_mat):
        norms[x_ind - 1] += prob_n[0]
        cnts[x_ind - 1] += 1
    for i in range(bin_num + 1):
        if cnts[i] != 0:
            norms[i] /= cnts[i]
    plt.figure(figsize=(8, 4))
    _, _, patches = plt.hist(X, bins=bins)
    cm = matplotlib.cm.get_cmap("coolwarm")
    colors = [cm(norm) for norm in norms]
    for patch, color in zip(patches, colors):
        patch.set_fc(color)
    
    plt.show()