Gunosyデータ分析ブログ

Gunosyで働くデータエンジニアが知見を共有するブログです。

【Edward】MCMCの数学的基礎からStochastic Gradient Langevin Dynamicsの実装まで

こんにちは。初めまして。
データ分析部新入りのmathetake(@mathetake)と申します。

先日個人ブログでこんなエントリを書いた人です:

mathetake.hatenablog.com

そんなこんなでTwitter就活芸人(?)として活動(?)してましたが、これからは真面目に頑張っていこうと思います。



今日はみんな大好きベイズモデリングおいて、事後分布推定に欠かせないアルゴリズム(群)の一つである*1

マルコフ連鎖モンテカルロ法(Markov chain Monte Carlo)

通称MCMCに関するエントリです。より具体的に、

MCMCの意義(§1.)から始め、マルコフ連鎖の数学的な基礎(§2.,3.,4.)、MCMCの代表的なアルゴリズムであるMetropolis-Hastings法(§5.)、その例の1つである*2Langevin Dynamics(§6.)、そして(僕の中で)絶賛大流行中のライブラリEdwardを使ってより発展的(?)なアルゴリズムであるStochastic Gradient Langevin Dynamicsの説明&実装(§7.,8.)していきたいと思います。


今までのデータ分析ブログのエントリと少しテイストが違うかもしれませんが、お楽しみいただけたら幸いです。

細心の注意を払ってはいますが、数学的正しさを保証する記事ではありませんので、詳細が気になる方は§Referenceにある資料を御覧ください。

はじめに

本題に入る前にまず、MCMCの意義について軽く触れておきます。

ベイズ統計におけるモデルパラメータの推定には事後分布からのサンプリングがかかせません。即ち、モデルのパラメータをθΘ,データの集合をDとして、ベイズの定理から得られる次のような確率分布

p(θ | D)=p(D | θ)p(θ)Θp(D | θ)p(θ)dθ

からサンプリングする必要がありますが、一般に分母( 正規化定数 )は解析的に求められず、サンプリングは困難です。

そこを、マルコフ連鎖の収束定理等の数学的基礎理論を使って上手くサンプリングする手法がMCMCで、広く応用されています。*3

ベイズ統計の数学的背景については、僕の個人ブログのエントリを御覧ください

mathetake.hatenablog.com

離散値マルコフ連鎖

X:={X0,X1,X2,,Xt,} を離散な集合{1,2,3,4,..,k}(kN) に値を取る、離散な確率変数の集合とします。確率的な値の時間発展だと考え、確率過程と呼ぶことにします。

(Definition 1.) Xマルコフ連鎖であるとは、全ての{xs}s=0t+1{1,2,,k}に対して

P(Xt+1=xt+1 | Xt=xt,X0=x0)=P(Xt+1=x | Xt=xt)   (1)

が成立する事。

感覚的には、次の時刻における値の分布は現在の値のみで決まり、それ以前の値には影響されない確率過程の事です。

数学的便宜上、次の定義も用意しておきます。

(Definition 2.) マルコフ連鎖X斉時的であるとは、

pi,jt:=P(Xt+1=j | Xt=i)

時間に依存しない事。またこの時 行列pi,j:=pi,jt遷移行列 と呼ぶ。

つまり状態遷移の確率が時間に依らず一定なマルコフ連鎖の事で、状態遷移を表す行列を遷移行列 としています。


ずっと数学の話になってしまいましたが、最後に一つだけMCMCに関わる重要な命題とその系を述べておきます。

(Proposition 3.) Xを斉時的なマルコフ連鎖とする。この時 Xt の分布は 遷移行列pi,jX0の分布(初期分布)により完全に決定される。

(Proof) X0,X1, の分布をπ0,π1,,とし、πjtで確率変数Xtが値j を取る確率とする。マルコフ連鎖の性質(1)より

πj1=i=1kP(X1=j | X0=i)P(X0=i)=i=1kπi0pi,j   (2)

また、

πjt=i=1kP(Xt=j | Xt1=i)P(Xt1=i)=i=1kπit1pi,j   (3)

が成立するので、帰納的に示される □

この命題の系として次が得られます

(Corollary 4.) 初期分布π0 と遷移行列 T:=(pi,j) を与えることで、式(2),(3)により斉時的マルコフ連鎖が得られる。

マルコフ連鎖収束定理とMCMC

あと少しMCMCの説明まで辿りつきます。もう少々数学にお付き合いください。

§1. で紹介したように、サンプリングしたい確率分布π=(πi)が手元(?)にあるとします。

(Definition 5.)遷移行列がT:=(pi,j) により与えられる、斉時的マルコフ連鎖Xπ不変分布に持つとは

i=1kπipi,j=πj   ( j{1,,k})

が成立する事。行列の式で書けば

πT=π

が成立すること。この時πX不変分布と言う。

さて、MCMCの肝となる定理は次のものです

(Theorem 6.(離散値マルコフ連鎖の収束定理))
π を不変分布に持つ斉時的マルコフ連鎖X(i) 非周期的 かつ (ii) 既約*4 である時、マルコフ連鎖は不変分布 π に収束する。即ち


limtP(Xt=i)=πi

が成立する。

ここまで来てやっと、MCMCの定義を与えることができます:

(Definition 7.) MCMC(Markov chain Monte Carlo)とは、サンプリングしたい確率分布π を不変分布とするような既約で非周期的なマルコフ連鎖を構築&サンプリング するアルゴリズムの事。

Theorem 6.により、MCMCにより生成されるサンプルの列{xt}t=0 は確率分布 πからのサンプルに収束し、目的を達成することができます。

既約性と非周期性を満たすようなマルコフ連鎖を構築するのはそんなに難しくはありません、が、サンプリングしたい確率分布π を不変分布とするようなマルコフ連鎖を構築するのは一般に困難です。

そこでよく用いられるのが*5詳細釣り合い条件と呼ばれる不変分布を持つための十分条件です:

(Definition 8.)遷移行列がT:=(pi,j) により与えられる、斉時的マルコフ連鎖X不変分布 π に対して詳細釣り合い条件を満たすとは


πi pi,j=πj pj,i    ( i,j{1,,k})


が成立する事。またこの時、マルコフ連鎖Xπ を不変分布に持つ。

連続値マルコフ連鎖の場合

今まで簡単のため、離散な値を持つマルコフ連鎖について話をしてきましたが、連続な確率変数の話に一般化する事ができます。

連続値( 便宜上Rd値とする )なマルコフ連鎖X={X0,X1,,Xt,}斉時的である時、離散値の場合の推移行列に対応する推移核T:Rd×RdR

P(XtA | Xt1=x)=AT(x,y)dy      (ARd)

を満たすものとして定めます。感覚的には現時刻で値 x を取る時、次の時刻の分布を表す密度関数です。

また、Xが分布πを不変分布に持つとは

Rdπ(x)T(x,y)dx=π(y)

を満たすことであり、詳細釣り合い条件

π(x) T(x,y)=π(y) T(y,x)   (4)

で与えられます。

(Theorem 9.(連続値マルコフ連鎖の収束定理))
π を不変分布に持つ斉時的マルコフ連鎖X(i) 非周期的 かつ (ii) 既約である時、Xtotal variation distanceの意味でπ に収束する。即ち、


limtsupARd|πt(A)π(A)| = 0


が成立する。

Metropolis-Hastings法

MCMCの代表的なアルゴリズム(群)である、Metropolis-Hastings法(以下M-H法)について説明します。

M-H法では、各 xに対して提案分布と呼ばれる確率分布 q(x,y)dyを用意し、採択確率と呼ばれる確率

α(x,y):=min{1,π(y)q(y,x)π(x)q(x,y)}

を準備します。そして推移核 T(x,y)

T(x,y)=α(x,y)q(x,y)+A(x)δx,y

として定義&適当な初期分布π0を与えることで斉時的マルコフ連鎖 XMHを考えます。(ここで δx,yx=yの時に1でそれ以外の時に0を取る関数、A(x) は正規化定数を与える関数。)

(Theorem 10.) 斉時的マルコフ連鎖 XMHπ を定常分布に持つ

(Proof) 定義から

π(x)q(x,y)α(x,y)=π(x)q(x,y)min{1,π(y)q(y,x)π(x)q(x,y)}=min{π(x)q(x,y),π(y)q(y,x)}=π(y)q(y,x)min{π(x)q(x,y)π(y)q(y,x),1}=π(y)q(y,x)α(y,x)

が従うので、クロネッカーのデルタの定義から(4)式が成立し詳細釣り合い条件が満たされる。□

Theorem 10.だけでは保証されない既約性や非周期性を満たすような推移核の具体的な設計は重要な課題ですが、ここではそのような性質が満たされ Theorem 6.(or 7.) が成立すると仮定しましょう。*6

この時マルコフ連鎖XMH からサンプリングする次のようなアルゴリズムを Metropolis-Hastings法と呼びます。

(Metropolis-Hastings法)
(1) 初期分布π0 から x0 をサンプリングする:

x0π0

(2) t=0,1,2,3, に対して、以下を実行する
 (i) 標準一様分布から乱数 u を生成する:

uU(0,1)
 (ii) y を提案分布 q(xt,y) からサンプリングする:
yq(xt,y)
 (ii) 次の式により"次の点" xt+1を決める:
xt+1 := {y     if  u<α(xt,y)xt    otherwise

このアルゴリズムが、実際に上で与えたXMH からサンプリングしている事は明らかでしょう。

§0. で述べたように採択確率の計算には目標の分布π(x))正規化定数が必要ない事に注目して下さい。

Langevin Dynamics

M-H法を実際に実行するためには、提案分布 q(x,y)dy を定義する必要があります。

ここではその例として、Langevin Dynamics法(Metropolis-adjusted Langevin Algorithm)(以下LD法) を紹介します。

LD法では提案分布を次のように定義します:

q(x,y):=N(x+ϵ2logπx(x),  ϵI)

ここでN(μ,Σ) は平均 μ, 分散 Σに従う正規分布で、ϵはstep size(またはlearning rate)と呼ばれるハイパーパラメータです。

ベイズモデルの事後分布に適応する場合において、決定論的に眺めると、LD法は事前分布で正規化した、ノイズ入り勾配降下法のようなものであると解釈することができます。

実際、事後分布推定において、π は観測データを D={di}i=1Nとして

π(θ)=p(θ | D)=p(θ)i=1Np(di | θ)Θp(D | θ)p(θ)dθ

で与えられ、その対数微分は

θlogπ(θ)=θlogp(θ)+θi=1Nlogp(di | θ)   (5)

のように計算できるので、上のような解釈ができます。

実はLG法は、Stanで有名になったHamiltonian Monte Carlo法の特別なケースと等価になっているので、気になる方は [1]や[2] を御覧ください。

Stochastic Gradient Langevin Dynamics(SGLD)

ここまでMCMCの数学的基礎からM-H法、そしてその具体例としてLG法を紹介しました。

LG法の問題点として、(5) 式の計算量がサンプル数が増えるほど膨大になっていく点があります。

近年はビッグデータと呼ばれるバズワードもあるように、サンプル数&パラメータ数が巨大なセッティングでモデリングする事が多いのでこのままLG法を適用する事はできません。*7


その問題点を克服するサンプリング手法として、ここで紹介するのが Stochastic Gradient Langevin Dynamics法 [8](以下SGLD法)です。

SGLD法では次のように初期分布からサンプルしたパラメータを更新していきます:

まず、次の2つの条件

t=0ϵt2< ,      t=0ϵt=

を満たす数列{ϵt}t=0を用意し、パラメータサンプル{θt}t=0を次の式によって取得して行きます

θt+1   θt+ϵt2Lθ(θt)+ηt ,   ηtN(0 ,ϵt) Lθ(θt)=logpθt(θt)+N|St|dStlogp(d | θ)θ(θt)

ここで、St はデータ{di}i=1N からランダムに抽出された N より十分小さいミニバッチとします。

注意として、このアルゴリズムに対応するマルコフ連鎖は斉時的ではないので、上述の収束定理は適用できません。

ですが、例えば [9]で収束性に関する解析がされています。

EdwardでのSGLDの実装

最後に確率モデリング用ライブラリEdwardを用いて、SGLDをベイズ的線形回帰に適用してみようと思います。

Edwardの詳しい使い方は公式チュートリアルまたは次の論文

[1701.03757] Deep Probabilistic Programming

[1610.09787] Edward: A library for probabilistic modeling, inference, and criticism

をご覧ください。また質問等ありましたら@mathetakeまで気軽にリプライorDMください。

まず各種ライブラリをimportします。

import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Normal, Empirical
import time

次にデータセットを用意します。

N = 20000  # サンプル数
D = 50  # 特徴量の次元
N_ITER = 10000  # MCMCのiteration
MINI_BATCH_SIZE = 2500  #ミニバッチのサイズ 

# toy dataset. 切片=0はなし.
def build_toy_dataset(N, D, noise_std=0.1):
    w = np.random.randn(D).astype(np.float32)
    X = np.random.randn(N, D).astype(np.float32)
    Y = np.dot(X, w) + np.random.normal(0, noise_std, size=N)
    return w, X, Y

# データ生成。観測値のノイズの分散は既知とする。
w_true, X_data, Y_data = build_toy_dataset(N, D)

# ミニバッチを返す関数 
def next_batch(mini_batch_size=128):
    indexes = np.random.randint(N, size=mini_batch_size)
    return X_data[indexes], Y_data[indexes]

モデルを構築し, 推論のためのインスタンスを作ります。

# 観測データを挿入するためのデータを収めるplaceholder
x = tf.placeholder(tf.float32, [MINI_BATCH_SIZE, D])
y_ph = tf.placeholder(tf.float32, [MINI_BATCH_SIZE])

w = Normal(mu=tf.zeros(D), sigma=tf.ones(D))
b = Normal(mu=tf.zeros(1), sigma=tf.ones(1))
y = Normal(mu=ed.dot(x, w) + b, sigma=tf.ones(MINI_BATCH_SIZE)*0.1)

# 経験分布をposteriorの近似に使う
qw = Empirical(params=tf.Variable(tf.random_normal([N_ITER, D])))
qb = Empirical(params=tf.Variable(tf.random_normal([N_ITER, 1])))

# SGLD法用インスタンス
SGLD = ed.SGLD(latent_vars={w: qw, b: qb}, data={y: y_ph})


最後に推論を実行します。

# 推論GO 
# data辞書にはobservedな確率変数の観測データを送る。
# xの値は確率変数ではないので、updateの際feed_dictで送る。
SGLD = ed.SGLD(latent_vars={w: qw, b: qb}, data={y: y_ph})
SGLD.initialize(scale={y: float(N) / MINI_BATCH_SIZE}, step_size=0.00001, n_iter=N_ITER)

start = time.time()
init = tf.global_variables_initializer()
init.run()
for _ in tqdm(range(N_ITER)):
    X_batch, Y_batch = next_batch(MINI_BATCH_SIZE)
    _ = SGLD.update(feed_dict={x: X_batch, y_ph: Y_batch})
elapsed_time = time.time() - start
print("elapsed_time:{}".format(elapsed_time))


実行結果ですが、ミニバッチで勾配計算をしない通常のLD法と比較してみました:

アルゴリズム 実行時間(s) MSE of W Estimated b
SGLD法 26.4 4.9×105 0.0020
LD法 55.4 5.3×107 0.0001

といった感じです。精度はLD法には劣りますが、実行時間は約2倍短いと言った感じです。

より大きなデータセット&複雑なモデルの場合、この差はより顕著になるでしょう。

Appendix

(Definition) 斉時的マルコフ連鎖X非周期的であるとは、任意のi{1,,k} に対して

gcd{n | P(Xn=i | X0=i)>0}=1

が成立する事。

(Definition) 斉時的マルコフ連鎖X既約であるとは, 任意のi,j{1,,k} に対してあるti,jN が存在して

P(Xti,j=j | X0=i)>0

が成立する事。

*1:他には変分ベイズ法などがあります

*2:Hamiltonian Monte Carlo法の一種でもあります。

*3: 事後分布推定に限らず、”正規化定数が分からない分布からサンプリングする手法” として広く使われています。

*4: (i) 非周期的 かつ (ii) 既約 の定義はAppendixに付けておきます。

*5:詳細釣り合い満たさないようなMCMCの研究が最近流行っている、らしいです。例えば 詳細つりあいを満たさないマルコフ連鎖モンテカルロ法とその一般化を御覧ください。

*6:M-Hについて既約性や非周期性が満たされるための条件は論文[5]をご覧ください。

*7:LG法にかぎらず、M-H法は採択確率の計算が基本的にintractableです。ですが、例えばハミルトニアンモンテカルロ法をビッグデータに適用する手法については次の論文があります : Stochastic Gradient Hamiltonian Monte Carlo