ガシンラーニング

マシンラーニング・統計

【Python実装】LDAのトピックをParticle Filter(SMC)で推論

 今回は、LDA(Latent Dirichlet Allocation)の逐次モンテカルロ法(Sequential Monte Calro)であるパーティクルフィルター(Particle Filter)によるトピック推論をPythonで実装しました。

コードは全てgithubに載せています。githubはこちら

 

以下の書籍3.5章とこの書籍が参照している元論文を参考にしました。

Online Inference of Topics with Latent Dirichlet Allocation [Canini 2009]こちら

こちらの書籍はトピックモデルに限らずベイズモデリング推論の良書です。

トピックモデルによる統計的潜在意味解析 (自然言語処理シリーズ)

トピックモデルによる統計的潜在意味解析 (自然言語処理シリーズ)

初めに

膨大な量の文章が日々新たに流れてくる状況で、全ての文章の情報を保存して、学習の度に全データを読み込み、変分ベイズMCMCなどのバッチ学習をすることは非常に効率の悪くなる場面があります。そこで今回はミニバッチ学習でもなく、オンライン学習であるパーティクルフィルターによるメモリ効率の良いトピック推論を紹介したいと思います。

本記事を読むに当たって、LDAの崩壊型ギブズサンプリングによるトピック推論、つまり下の式の意味が分かることが望ましいです。導出過程は参考書籍に載ってます。

(文章di番目の単語がwd,i=vである時の潜在トピックzd,iの条件付き分布)

p(zd,i=k|wd,i=v,zd,i,wd,i,α,β)

nk,vd,i+βvv(nk,vd,i+βv)nd,kd,i+αkk(nd,kd,i+αk)

LDA(Latent Dirichlet Allocation)とは?

LDAの説明をしていませんでした。

LDAとは各文書に潜在的なトピック(意味)があると仮定した、文章の確率的生成モデルです。単語の相対的な出現頻度に注目しており、出現順序は無視しています。

f:id:gashin_learning:20191002115023p:plain

データの生成過程として、①「文章毎のカテゴリカル分布」からトピックが決まる。②「①で決まったトピックのカテゴリカル分布」から単語が決まる。という流れで、単語の数だけこの生成過程を辿ります。カテゴリカル分布のパラメータθd,ϕkは各々ディリクレ分布から確率的に決まります。

Particle Filter パーティクルフィルターとは?

LDAに限らず、一般的な説明をします。

パーティクルフィルター(PariticleFilter)とは、逐次インポータンスサンプリング(Seqential Importance Sampling)リサンプリングを組み合わせた逐次ベイズフィルターのことで、潜在状態変数の事後分布p(z1:t|x1:t)を多数の「パーティクル」と「重み」によって近似します。(z1:t={z1,z2,z3,,zt}とする)

まずインポータンスサンプリングの説明から入ります。数式は参考書籍を基に記載し、適時補足します。

インポータンスサンプリング(Importance Sampling)

p(z1:t|x1:t)が複雑な分布の場合、サンプルが容易な提案分布q(z1:t|x1:t)からのサンプルz1:t(s)q(z1:t|x1:t)(これがパーティクル)を用いて次のように近似します。

p(z1:t|x1:t)=w(z1:t)q(z1:t|x1:t)ps~(z1:t|x1:t)
ps~(z1:t|x1:t)=s=1Sw(z1:t(s))δ(z1:t=z1:t(s))s=1Sw(z1:t(s))
ここで
w(z1:t)=p(z1:t|x1:t)q(z1:t|x1:t)
このw(z1:t)が重みと言われているものです。 

f:id:gashin_learning:20191002122057p:plain
 重みは正規化するため定数をかけても良いです。p(z1:t|x1:t)の数式が分からなければ、p(x1:t)をかけて以下のようにしてもよいです。

w(z1:t)=p(z1:t|x1:t)q(z1:t|x1:t)p(z1:t,x1:t)q(z1:t|x1:t)
同時分布にすればグラフィカルモデルに合わせて分解すれば求まります。

このps~(z1:t|x1:t)を使うと、例えばp(zt+1|x1:t)を求めたいときにも

p(zt+1|x1:t)=p(zt+1|z1:t)p(z1:t|x1:t)dz1:t
p(zt+1|z1:t)ps~(z1:t|x1:t)dz1:t
=s=1Sw(z1:t(s))p(zt+1|z1:t(s))s=1Sw(z1:t(s))
というように近似します。

 逐次インポータンスサンプリングでは、このps~(z1:t|x1:t)からz1:t(s)を一度にサンプルするのではなく、zc(s)(c=1,2,3,,,t)を1つずつ逐次的にサンプルしていきます。

逐次インポータンスサンプリング(Seqential Importance Sampling)

 提案分布q(z1:t|x1:t)は自由に選ぶことができるため、以下のような分解を仮定します。

q(z1:t|x1:t)=q(zt|x1:t,z1:t1)q(zt1|x1:t1,z1:t2)q(z1|x1,z0)q(z0)
=q(z0)ctq(zc|x1:c,z1:c1)
つまり、
q(z1:t|x1:t)=q(zt|x1:t,z1:t1)q(z1:t1|x1:t1)
という提案分布の更新式を得ます。ここで数学的にはq(z1:t|x1:t)=q(zt|x1:t,z1:t1)q(z1:t1|x1:t)となりますが、q(z1:t1|x1:t)=q(z1:t1|x1:t1)と仮定します(提案分布なので自由に選べます。q(z1:t|x1:t)にどのような構造を仮定するかは自由で、どのみち重みの更新式の計算で辻褄を合わせることになります)。提案分布の更新式q(zt|x1:t,z1:t1)になり、初期分布q(z0)tまでの更新式の積としてq(z1:t|x1:t)を得ます。

重みも逐次的に更新します。重みの更新式は

w(z1:t(s))w(z1:t1(s))=p(z1:t(s)|x1:t)p(z1:t1(s)|x1:t1)q(z1:t1(s)|x1:t1)q(z1:t(s)|x1:t)
右辺第一項の分子分母に定数(xの確率分布)をかけて同時分布にします。右辺第二項は提案分布の更新式を使ってます。
p(z1:t(s),x1:t)p(z1:t1(s),x1:t1)1q(zt(s)|x1:t,z1:t1(s))
=p(zt(s),xt|x1:t1,z1:t1(s))q(zt(s)|x1:t,z1:t1(s))
=p(xt|zt(s),x1:t1,z1:t1(s))p(zt(s)|x1:t1,z1:t1(s))q(zt(s)|x1:t,z1:t1(s))
となります。

ちょっと脱線して、ブートストラップフィルター

ここで逐次的な提案分布をq(zt(s)|x1:t,z1:t1(s))=p(zt(s)|x1:t1,z1:t1(s))とすれば、xt(s)の情報が失われているので正確ではありませんが、重み更新式の一部分子分母が相殺されてp(xt|zt(s),x1:t1,z1:t1(s))だけが重み更新式として残ります。これが1次マルコフ状態空間モデルの場合、提案分布の更新式は「システムモデル」となり、重みの更新式は「観測モデル(尤度)」となりますね。これがブートストラップフィルターと呼ばれるパーティクルフィルターの特殊な場合になります。ブートストラップフィルターをパーティクルフィルターとして紹介している和書も多いです。(モンテカルロフィルターもこれと同じ?)

リサンプリング(Seqential Importance Sampling)

SISで逐次的に重みを更新していくと、サンプルサイズが増えるに従い、重みの分散が指数的に増加していきます。リサンプリングによって回避します。リサンプリングとは、重みの大きいパーティクルの生成と重みの小さいパーティクルの消滅を意味します。

f:id:gashin_learning:20191002124112p:plain

リサンプリングの後、元々の各パーティクルの重みは粒子の数として反映されたので、1/Sにリセットされます。

リサンプリングではモンテカルロエラーを加えることになるため、毎回実施せずにEESと呼ばれる偏りの指標をもって、設定した閾値を下回った時に実施します。EES=1s=1S(w¯(z1:t(s)))2

リサンプリング手法も色々提案されていますが、今回のLDA推論では、元論文でも採用されているResidualResamplingにて行います。

(Liu and Chen 1998: "Sequential Monte Carlo Methods for Dynamic Systems"より)

f:id:gashin_learning:20191002124727p:plain
簡単に説明すると、「パーティクル数と正規化した重みとの積」を整数部分と少数部分に分けて、①整数部分の数だけ該当するパーティクルを複製する。②「元々のパーティクルの数」から「①で複製した後のパーティクルの合計数」を引いた数だけ、少数部分に比例した確率で抽出するといった手続きです。

 

一般的なパーティクルフィルターは以下の繰り返しです。

for t in range(data_num):

  1. 逐次的インポータンスサンプリングによりzt(s)S個サンプル
  2. 重み更新
  3. ESS計算。もしESSが閾値以下ならば、リサンプリング

LDAのトピックをParticleFilterで推論

パーティクルフィルターはマルコフ性のある状態空間モデルの潜在変数の推定に使われることが多いですが、今回LDAのパラメータθd,ϕkを周辺化した場合に適用します。1次マルコフ性はないことに注意してください。(そのためパーティクルフィルタの説明でも1次マルコフ性を仮定せずに定式化していました。)

f:id:gashin_learning:20191002142933p:plain

LDAは時系列を仮定していないため文章の順番は適当に決めます。更新日時とかでいいんじゃないでしょうか。

前章までで、逐次的な提案分布は

q(zt|x1:t,z1:t1)
重みの更新式は
w(z1:t(s))w(z1:t1(s))=p(xt|zt(s),x1:t1,z1:t1(s))p(zt(s)|x1:t1,z1:t1(s))q(zt(s)|x1:t,z1:t1(s))
でした。

LDAのパーティクルフィルターでは提案分布をq(zt|x1:t,z1:t1)=p(zt|x1:t,z1:t1)とし、これをLDAに合わせて書き直します。文章×単語の2次元行列を1次元に並べて

p(zd,i|wd,i,z(d,i1),w(d,i1),α,β)
p(wd,i|zd,i=k,z(d,i1),w(d,i1),β)p(zd,i=k|z(d,i1),α)
nk,v(d,i1)+βvv(nk,v(d,i1)+βv)nd,k(d,i1)+αkk(nd,k(d,i1)+αk)
(この式は崩壊型ギブズサンプリングの式と同じ!ただし、x=wとしています。wd,iは文章dのi番目の単語。w(d,i)=(w1:d,wd,1:i)としています。)

重みの更新式は

w(z(d,i)(s))w(z(d,i1)(s))
=p(wd,i|zd,i,z(d,i1),w(d,i1),β)p(zd,i|z(d,i1),α)p(zd,i|wd,i,z(d,i1),w(d,i1),α,β)
=k=1Kp(wd,i|zd,i=k,z(d,i1),w(d,i1),β)p(zd,i=k|z(d,i1),α)
=k=1Knk,v(d,i1)+βvv(nk,v(d,i1)+βv)nd,k(d,i1)+αkk(nd,k(d,i1)+αk)
これも崩壊型ギブズサンプリングの式をzd,iで周辺化しただけであるため計算容易です。

冒頭でパーティクルフィルターによる逐次学習によって、オンラインで更新された分のデータだけを使ってトピックを推論することができますと書きましたが、実は単語毎のトピック頻度nk,v(d,i1)は保持する必要があります。ただこれは「ユニーク単語数×トピック」の行列で表現できるので文章が増えてもそんなに変わりにくいです。代わりにnd,k(d,i1)は保持しなくてよいのが利点です。こちらは「文章数×トピック」なので文章数が容量が直結するからです。

リサンプリングについては、書籍ではリサンプリングを実施しないと書いていますが、元論文では実施しているためResidual Resamplingを実装しました。

元論文では、この後「若返り(rejuvenate)」というステップがあります。これは過去のトピックのサンプルを今時点までの全情報を使い、崩壊型ギブズサンプリングで再度サンプルし直すという手続きです。パーティクルの多様性を保つことが可能と説明していますが、後続論文でパフォーマンスとしてはあまり意味が無いとなっています。文章の潜在変数情報を記憶しておく必要があり、メモリ削減の利点が薄れるため実装には入れていません。 若返りたい人は崩壊型ギブズサンプリングの実装はあるため、文章のトピックを保持するように変えれば付け足し可能です。

実装

パーティクルフィルターはNumpyで実装しました。データ前処理&評価可視化にはscikit-learnを使っています。

パーティクルフィルタのアルゴリズムだけオンライン処理しています。文章の読み込み&ベクトル化の箇所はオンラインにしていないので、ここは用途に合わせてください。

コードは全てGithubに載せています。Githubはこちら

 SIS
# Sequential Importance Sampling

## p(z_k| z_{1:k-1})
n_d_alpha_cond_d_particles = np.sum(n_dk_alpha_cond_d_particles, axis=0)
p_z = n_dk_alpha_cond_d_particles /n_d_alpha_cond_d_particles

## p(w_k| z_{1:k}, w_{1:k-1})
n_vk_beta_cond_k_particles = n_vk_beta_particles[word_id, :, :]
n_k_beta_cond_k_particles = np.sum(n_vk_beta_cond_k_particles, axis=0)
p_w = n_vk_beta_cond_k_particles /n_k_beta_cond_k_particles

## p(z_k| z_{1:k-1})p(w_k| z_{1:k}, w_{1:k-1})
sample_p = p_z * p_w
sample_p_sum = np.sum(sample_p, axis=0)
sample_p /=sample_p_sum

# sampling from multinomial distribution
random_sampling = np.random.uniform(0,1, size=particle_num)
sample_p_cumsum = np.cumsum(sample_p, axis=0)
topic_new = multinomial_particles_numba(topic_new, random_sampling, sample_p_cumsum, particle_num)
topic_particles[idx] = topic_new

# update parameters
n_dk_alpha_cond_d_particles[topic_new, np.arange(100)] += 1.0
n_vk_beta_particles[word_id, topic_new, np.arange(100)] += 1.0

# update weight
w*= sample_p_sum
w /= np.sum(w)

# ESS
ESS = 1/np.sum(w**2)
resampling
# Residual Resampling 
K = particle_num * w
K_int = np.trunc(K)
redisual_p = K - K_int
redisual_p/=np.sum(redisual_p)

K_int = K.astype('int') # main resample
M = particle_num - np.sum(K_int)
residual_resampling_array = np.random.multinomial(n=M, pvals=redisual_p) # residual resample

resampling_idx = K_int + residual_resampling_array
particle_idx = np.repeat(np.arange(particle_num), resampling_idx)
topic_particles = topic_particles[:,particle_idx]
n_vk_beta_particles = n_vk_beta_particles[:,:,particle_idx]
n_dk_alpha_cond_d_particles = n_dk_alpha_cond_d_particles[:,particle_idx]

w = np.ones(particle_num)/particle_num

実験

データセット

データセットはsklearnのfetch_20newsgroupsを使用しています。簡単のため、このうち'rec.sport.baseball', 'talk.religion.misc','comp.graphics', 'sci.space'のみ使用しました。

出現頻度min_df=0.005, max_df=0.1で、かつ名詞/固有名詞のみを抽出しています。そのためトピック数も4に設定しています。

推論について

実際には、最初のデータの20%を使って崩壊型ギブズサンプリングでパラメーターを初期化します。このステップは精度の面で重要とされています。

ワードクラウド 

推論されたトピック毎の頻出ワードを列挙します。

f:id:gashin_learning:20191103162031p:plain

topic0は宇宙。topic1はパソコン。topic2は野球。topic3は宗教と元のニュースと同じトピックを抽出できています。

文章のラベルプロット

元データにどのニュース記事かのラベルがついているため(推論には使用していない)、t-SNEで2次元にして可視化して、予測とラベルの類似度をみました。

f:id:gashin_learning:20191103165515p:plain
色がラベル(どのニュース記事か)。座標が推論トピックの事後確率(4次元)を2次元に圧縮した時の位置。

所々ミスしていますが、一応色のかたまりが見てとれます。
 

高度化に向けて

元論文で提案されているパーティクル毎のトピックを木構造で格納することでメモリ効率を上げる手法は今回実装していないので残課題です。

 最後に

パーティクルフィルターの利点はデータに対して、一回しかサンプル(一回で粒子数サンプルする)しないため、推論に使用し終わった後はデータを保持しておく必要がないところです。

ただ、メモリに乗るくらいの文章量ならバッチ学習で全然構いません。その方が全てのデータを加味して推論できるため、一般的には精度も良いと思います。

場面に合わせた推論手法を選ぶと良いと思います。

Appendix ~ NumbaによるPython高速化~

今回、一部Numbaで高速化しました。

NumbaとはJITコンパイラで、既存のPythonコードを少し手を加えるだけで高速化できます。Forループの処理が爆速になります。https://github.com/numba/numba

最も簡単な使い方は以下のように関数にデコレータをつけるだけです。

# using Numba 
@jit(nopython=True)
def collapsed_gibbs_sampling():
~

# Not using Numba
def collapsed_gibbs_sampling():
~

例えば崩壊型ギブズサンプリングをNumbaで高速化したので、Numba使わない場合と速度比較すると100倍以上速くなっています。

f:id:gashin_learning:20191103173909p:plain

ただ行列やIf文などは現時点で変換が上手くいかないようです。またサポートされていない関数が結構あります。。。axis系だめ。。。。そのためパーティクルフィルタには使いませんでした。

[サポート関数一覧]

http://numba.pydata.org/numba-doc/latest/reference/numpysupported.html

今後、少しずつ使っていきたいです。