今回は、LDA(Latent Dirichlet Allocation)の逐次モンテカルロ法(Sequential Monte Calro)であるパーティクルフィルター(Particle Filter)によるトピック推論をPythonで実装しました。
コードは全てgithubに載せています。githubはこちら
以下の書籍3.5章とこの書籍が参照している元論文を参考にしました。
Online Inference of Topics with Latent Dirichlet Allocation [Canini 2009]こちら
こちらの書籍はトピックモデルに限らずベイズモデリング推論の良書です。
トピックモデルによる統計的潜在意味解析 (自然言語処理シリーズ)
- 作者: 佐藤一誠,奥村学
- 出版社/メーカー: コロナ社
- 発売日: 2015/03/13
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (5件) を見る
初めに
膨大な量の文章が日々新たに流れてくる状況で、全ての文章の情報を保存して、学習の度に全データを読み込み、変分ベイズやMCMCなどのバッチ学習をすることは非常に効率の悪くなる場面があります。そこで今回はミニバッチ学習でもなく、オンライン学習であるパーティクルフィルターによるメモリ効率の良いトピック推論を紹介したいと思います。
本記事を読むに当たって、LDAの崩壊型ギブズサンプリングによるトピック推論、つまり下の式の意味が分かることが望ましいです。導出過程は参考書籍に載ってます。
(文章の番目の単語が=である時の潜在トピックの条件付き分布)
LDA(Latent Dirichlet Allocation)とは?
LDAの説明をしていませんでした。
LDAとは各文書に潜在的なトピック(意味)があると仮定した、文章の確率的生成モデルです。単語の相対的な出現頻度に注目しており、出現順序は無視しています。
データの生成過程として、①「文章毎のカテゴリカル分布」からトピックが決まる。②「①で決まったトピックのカテゴリカル分布」から単語が決まる。という流れで、単語の数だけこの生成過程を辿ります。カテゴリカル分布のパラメータは各々ディリクレ分布から確率的に決まります。
Particle Filter パーティクルフィルターとは?
LDAに限らず、一般的な説明をします。
パーティクルフィルター(PariticleFilter)とは、逐次インポータンスサンプリング(Seqential Importance Sampling)とリサンプリングを組み合わせた逐次ベイズフィルターのことで、潜在状態変数の事後分布を多数の「パーティクル」と「重み」によって近似します。(とする)
まずインポータンスサンプリングの説明から入ります。数式は参考書籍を基に記載し、適時補足します。
インポータンスサンプリング(Importance Sampling)
が複雑な分布の場合、サンプルが容易な提案分布からのサンプル(これがパーティクル)を用いて次のように近似します。
重みは正規化するため定数をかけても良いです。の数式が分からなければ、をかけて以下のようにしてもよいです。
このを使うと、例えばを求めたいときにも
逐次インポータンスサンプリングでは、このからを一度にサンプルするのではなく、を1つずつ逐次的にサンプルしていきます。
逐次インポータンスサンプリング(Seqential Importance Sampling)
提案分布は自由に選ぶことができるため、以下のような分解を仮定します。
重みも逐次的に更新します。重みの更新式は
ちょっと脱線して、ブートストラップフィルター
ここで逐次的な提案分布をとすれば、の情報が失われているので正確ではありませんが、重み更新式の一部分子分母が相殺されてだけが重み更新式として残ります。これが1次マルコフ状態空間モデルの場合、提案分布の更新式は「システムモデル」となり、重みの更新式は「観測モデル(尤度)」となりますね。これがブートストラップフィルターと呼ばれるパーティクルフィルターの特殊な場合になります。ブートストラップフィルターをパーティクルフィルターとして紹介している和書も多いです。(モンテカルロフィルターもこれと同じ?)
リサンプリング(Seqential Importance Sampling)
SISで逐次的に重みを更新していくと、サンプルサイズが増えるに従い、重みの分散が指数的に増加していきます。リサンプリングによって回避します。リサンプリングとは、重みの大きいパーティクルの生成と重みの小さいパーティクルの消滅を意味します。
リサンプリングの後、元々の各パーティクルの重みは粒子の数として反映されたので、にリセットされます。
リサンプリングではモンテカルロエラーを加えることになるため、毎回実施せずにと呼ばれる偏りの指標をもって、設定した閾値を下回った時に実施します。
リサンプリング手法も色々提案されていますが、今回のLDA推論では、元論文でも採用されているResidualResamplingにて行います。
(Liu and Chen 1998: "Sequential Monte Carlo Methods for Dynamic Systems"より)
簡単に説明すると、「パーティクル数と正規化した重みとの積」を整数部分と少数部分に分けて、①整数部分の数だけ該当するパーティクルを複製する。②「元々のパーティクルの数」から「①で複製した後のパーティクルの合計数」を引いた数だけ、少数部分に比例した確率で抽出するといった手続きです。
一般的なパーティクルフィルターは以下の繰り返しです。
for t in range(data_num):
- 逐次的インポータンスサンプリングによりを個サンプル
- 重み更新
- ESS計算。もしESSが閾値以下ならば、リサンプリング
LDAのトピックをParticleFilterで推論
パーティクルフィルターはマルコフ性のある状態空間モデルの潜在変数の推定に使われることが多いですが、今回LDAのパラメータを周辺化した場合に適用します。1次マルコフ性はないことに注意してください。(そのためパーティクルフィルタの説明でも1次マルコフ性を仮定せずに定式化していました。)
LDAは時系列を仮定していないため文章の順番は適当に決めます。更新日時とかでいいんじゃないでしょうか。
前章までで、逐次的な提案分布は
LDAのパーティクルフィルターでは提案分布をとし、これをLDAに合わせて書き直します。文章×単語の2次元行列を1次元に並べて
重みの更新式は
冒頭でパーティクルフィルターによる逐次学習によって、オンラインで更新された分のデータだけを使ってトピックを推論することができますと書きましたが、実は単語毎のトピック頻度は保持する必要があります。ただこれは「ユニーク単語数×トピック」の行列で表現できるので文章が増えてもそんなに変わりにくいです。代わりには保持しなくてよいのが利点です。こちらは「文章数×トピック」なので文章数が容量が直結するからです。
リサンプリングについては、書籍ではリサンプリングを実施しないと書いていますが、元論文では実施しているためResidual Resamplingを実装しました。
元論文では、この後「若返り(rejuvenate)」というステップがあります。これは過去のトピックのサンプルを今時点までの全情報を使い、崩壊型ギブズサンプリングで再度サンプルし直すという手続きです。パーティクルの多様性を保つことが可能と説明していますが、後続論文でパフォーマンスとしてはあまり意味が無いとなっています。文章の潜在変数情報を記憶しておく必要があり、メモリ削減の利点が薄れるため実装には入れていません。 若返りたい人は崩壊型ギブズサンプリングの実装はあるため、文章のトピックを保持するように変えれば付け足し可能です。
実装
パーティクルフィルターはNumpyで実装しました。データ前処理&評価可視化にはscikit-learnを使っています。
パーティクルフィルタのアルゴリズムだけオンライン処理しています。文章の読み込み&ベクトル化の箇所はオンラインにしていないので、ここは用途に合わせてください。
コードは全てGithubに載せています。Githubはこちら
SIS
# Sequential Importance Sampling ## p(z_k| z_{1:k-1}) n_d_alpha_cond_d_particles = np.sum(n_dk_alpha_cond_d_particles, axis=0) p_z = n_dk_alpha_cond_d_particles /n_d_alpha_cond_d_particles ## p(w_k| z_{1:k}, w_{1:k-1}) n_vk_beta_cond_k_particles = n_vk_beta_particles[word_id, :, :] n_k_beta_cond_k_particles = np.sum(n_vk_beta_cond_k_particles, axis=0) p_w = n_vk_beta_cond_k_particles /n_k_beta_cond_k_particles ## p(z_k| z_{1:k-1})p(w_k| z_{1:k}, w_{1:k-1}) sample_p = p_z * p_w sample_p_sum = np.sum(sample_p, axis=0) sample_p /=sample_p_sum # sampling from multinomial distribution random_sampling = np.random.uniform(0,1, size=particle_num) sample_p_cumsum = np.cumsum(sample_p, axis=0) topic_new = multinomial_particles_numba(topic_new, random_sampling, sample_p_cumsum, particle_num) topic_particles[idx] = topic_new # update parameters n_dk_alpha_cond_d_particles[topic_new, np.arange(100)] += 1.0 n_vk_beta_particles[word_id, topic_new, np.arange(100)] += 1.0 # update weight w*= sample_p_sum w /= np.sum(w) # ESS ESS = 1/np.sum(w**2)
resampling
# Residual Resampling K = particle_num * w K_int = np.trunc(K) redisual_p = K - K_int redisual_p/=np.sum(redisual_p) K_int = K.astype('int') # main resample M = particle_num - np.sum(K_int) residual_resampling_array = np.random.multinomial(n=M, pvals=redisual_p) # residual resample resampling_idx = K_int + residual_resampling_array particle_idx = np.repeat(np.arange(particle_num), resampling_idx) topic_particles = topic_particles[:,particle_idx] n_vk_beta_particles = n_vk_beta_particles[:,:,particle_idx] n_dk_alpha_cond_d_particles = n_dk_alpha_cond_d_particles[:,particle_idx] w = np.ones(particle_num)/particle_num
実験
データセット
データセットはsklearnのfetch_20newsgroupsを使用しています。簡単のため、このうち'rec.sport.baseball', 'talk.religion.misc','comp.graphics', 'sci.space'のみ使用しました。
出現頻度min_df=0.005, max_df=0.1で、かつ名詞/固有名詞のみを抽出しています。そのためトピック数も4に設定しています。
推論について
実際には、最初のデータの20%を使って崩壊型ギブズサンプリングでパラメーターを初期化します。このステップは精度の面で重要とされています。
ワードクラウド
推論されたトピック毎の頻出ワードを列挙します。
topic0は宇宙。topic1はパソコン。topic2は野球。topic3は宗教と元のニュースと同じトピックを抽出できています。
文章のラベルプロット
元データにどのニュース記事かのラベルがついているため(推論には使用していない)、t-SNEで2次元にして可視化して、予測とラベルの類似度をみました。
色がラベル(どのニュース記事か)。座標が推論トピックの事後確率(4次元)を2次元に圧縮した時の位置。
所々ミスしていますが、一応色のかたまりが見てとれます。
高度化に向けて
元論文で提案されているパーティクル毎のトピックを木構造で格納することでメモリ効率を上げる手法は今回実装していないので残課題です。
最後に
パーティクルフィルターの利点はデータに対して、一回しかサンプル(一回で粒子数サンプルする)しないため、推論に使用し終わった後はデータを保持しておく必要がないところです。
ただ、メモリに乗るくらいの文章量ならバッチ学習で全然構いません。その方が全てのデータを加味して推論できるため、一般的には精度も良いと思います。
場面に合わせた推論手法を選ぶと良いと思います。
Appendix ~ NumbaによるPython高速化~
今回、一部Numbaで高速化しました。
NumbaとはJITコンパイラで、既存のPythonコードを少し手を加えるだけで高速化できます。Forループの処理が爆速になります。https://github.com/numba/numba
最も簡単な使い方は以下のように関数にデコレータをつけるだけです。
# using Numba @jit(nopython=True) def collapsed_gibbs_sampling(): ~ # Not using Numba def collapsed_gibbs_sampling(): ~
例えば崩壊型ギブズサンプリングをNumbaで高速化したので、Numba使わない場合と速度比較すると100倍以上速くなっています。
ただ行列やIf文などは現時点で変換が上手くいかないようです。またサポートされていない関数が結構あります。。。axis系だめ。。。。そのためパーティクルフィルタには使いませんでした。
[サポート関数一覧]
http://numba.pydata.org/numba-doc/latest/reference/numpysupported.html
今後、少しずつ使っていきたいです。