<p><a href="http://statmodeling.hatenablog.com/archive">StatModeling Memorandum</a></p>

StatModeling Memorandum

StanとRとPythonでベイズ統計モデリングします. たまに書評.

PythonのSymPyで変分ベイズの例題を理解する

この記事の続きです。

ここではPRMLの10.1.3項の一変数ガウス分布の例題(WikipediaVariational_Bayesian_methodsのA basic exampleと同じ)をSymPyで解きます。すなわちデータが

 YnNormal(μ,τ1)  n=1,..,N

に従い*1μτが、

 μNormal(μ0,(λ0τ)1)

 τGamma(a0,b0)

に従うという状況です。ここでデータYnn=1,...,N)が得られたとして事後分布p(μ,τ|Y)を変分ベイズで求めます。

まずはじめに、上記の確率モデルから同時分布p(Y,μ,τ)を書き下しておきます。

 p(Y,μ,τ)=p(Y|μ,τ)p(μ|τ)p(τ)

なので、

 p(Y,μ,τ)=n=1NNormal(Yn|μ,τ1)Normal(μ|μ0,(λ0τ)1)Gamma(τ|a0,b0)

となります。

この問題は単純なので事後分布は厳密に求まるのですが、ここでは変分ベイズで解きます。すなわち、事後分布p(μ,τ|Y)q(μ,τ)で近似します。さらにq(μ,τ)=q(μ)q(τ)と因子分解可能と仮定します。そして、前の記事の最後の2つの式を使って、q(μ)q(τ)が収束するまで繰り返し交互に更新して求めるのでした。以下ではこれをSymPyでやります。

from sympy import *
from sympy.stats import *
init_printing(use_unicode=True)

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'
  • 3行目: 僕は基本的にはJupyter Notebookで実行しています。この行を追加することで、数式がMathJaxで綺麗に表示されます。
  • 5~6行目: セルの途中で出力しても数式が綺麗に表示されるようにしています。こちらの記事を参考にしました。
y, mu, mu0 = symbols('y mu mu0', real=True)
Y_vec = symbols('Y1:4', real=True)
tau, lambda0, a0, b0 = symbols('tau lambda0 a0 b0', positive=True)
  • 1行目: SymPyで使う変数はsymbols関数で作成しておく必要があります。real=Trueと指定することで、実数と仮定することができます。何も指定しなければ複素数になります。このように仮定を入れておかないと、のちの式変形や積分がうまくいかない場合があります。
  • 2行目: このように変数のリストを作成することもできます。
  • 3行目: positive=Trueと指定することで、正の実数だと仮定することができます。

なお、SymPyでは要素数やデータ数をNとするような一般の場合の式変形は基本的に難しいです。しかし具体的な値に決めれば実行できます。そこで、ここでは2行目でデータ数をY1,Y2,Y33個として先に進めます。あとで値を色々変えて試すとNの場合の見当がつくので、そこから一般の場合を証明することもできます。

p_y = density(Normal('', mu, 1/sqrt(tau)))(y)
p_mu = density(Normal('', mu0, 1/sqrt(lambda0*tau)))(mu)
p_tau = density(Gamma('', a0, 1/b0))(tau)

sympy.statsには確率分布の密度関数の式がありますので、それを使っています。ここではデータ1つあたりのyの分布とmutauの事前分布を定義しています。

SymPyの正規分布Normal(平均, 標準偏差)なので、精度であるtau1/sqrt(tau)を代入しています。また、PRMLWikipediaのガンマ分布はGamma(shape, rate)である一方*2、SymPyのガンマ分布はGamma(shape, scale)なので、1/b0を代入しています。

integrate(p_mu, (mu, -oo, oo))
simplify(integrate(p_tau, (tau, 0, oo)))

試しにmuの分布をからまで積分してみましょう。期待通り1が返ります。tauの分布でも同様に積分すると1にならずに整理されていない式が返ってきますが、simplify関数で整理すると1になります。

同時分布の対数(log p)の準備

前の記事の最後の2つの式でやっていることを日本語で書くと以下です。

  • 同時分布の対数log p(Y,μ,τ)q(τ)を掛けてτ積分して、μの分布q(μ)を求める。
  • 同時分布の対数log p(Y,μ,τ)q(μ)を掛けてμ積分して、τの分布q(τ)を求める。

そこでまず同時分布の対数を準備します。

log_p = sum([log(p_y.subs(y, x)) for x in Y_vec]) + log(p_mu) + log(p_tau)
log_p = simplify(log_p)
log_p

 Y12τ2+Y1μτY22τ2+Y2μτY32τ2+Y3μτ+a0log(b0)+a0log(τ)b0τλ0τ2μ2+λ0μμ0τλ0τ2μ023τ2μ2+log(a0)+12log(λ0)+log(τ)log(Γ(a0+1))2log(π)2log(2)

  • 1行目: expr.subs(y, x)は式expryxを代入します。

次に積分にすすみます。

μを含まない項にμを含まない分布q(τ)を掛けてτ積分したところで、やはりμに関係がない定数になります。定数は最後に規格化して求めればよいので、途中の計算はなるべく簡単になるように余計な項を取り除きます。これがSymPyで計算をうまくさせるポイントになります。

log_p_for_mu = integrate(diff(log_p, mu), mu)
log_p_for_mu = collect(log_p_for_mu, mu)
log_p_for_mu

 μ2(λ0τ23τ2)+μ(Y1τ+Y2τ+Y3τ+λ0μ0τ)

  • 1行目: log_pmu微分してmu積分することで、muを含まない項を取り除いています。
  • 2行目: collect関数はmuの関数として式をみたときに共通部分をくくります。
log_p_for_tau = integrate(diff(log_p, tau), tau)
log_p_for_tau = collect(log_p_for_tau, tau)
log_p_for_tau

 τ(Y122+Y1μY222+Y2μY322+Y3μb0λ0μ22+λ0μμ0λ0μ0223μ22)+(a0+1)log(τ)

μ積分してq(τ)を求める方も同様なのでそうしておきます。

できるところまで解析的に求める

SymPyの練習のため、事前分布から積分を1回実行してq(μ)q(τ)を求めるところをやってみます。

log_q1_mu = integrate(log_p_for_mu * p_tau, (tau, 0, oo))
log_q1_mu
log_q1_mu = simplify(log_q1_mu)
log_q1_mu
log_q1_mu = collect(expand(log_q1_mu), mu)
log_q1_mu

 μ2(a0λ02b03a02b0)+μ(Y1a0b0+Y2a0b0+Y3a0b0+a0μ0b0λ0)

  • 5行目: 3行目でsimplifyしていますが、解析者が意図しない形になることはよくあります。ここでは、expandしてcollectすることでmu多項式にしています。

log_q1_muの式はμの二次関数のマイナスなので、このすぐあとのq1_mu正規分布になることが分かります。共役事前分布を使っているからです。規格化定数をもとめて規格化しましょう。

q1_mu = exp(log_q1_mu)
const = simplify(integrate(q1_mu, (mu, -oo, oo)))
const
q1_mu = 1/const * exp(log_q1_mu)
q1_mu

constが規格化定数になります。以下の部分です。

 2πb0a0λ0+3ea0(Y1+Y2+Y3+λ0μ0)22b0(λ0+3)

q1_muは規格化された分布のq(μ)です。以下になります。

 2a0λ0+32πb0ea0(Y1+Y2+Y3+λ0μ0)22b0(λ0+3)eμ2(a0λ02b03a02b0)+μ(Y1a0b0+Y2a0b0+Y3a0b0+a0μ0b0λ0)

同じようにq1_tauを求めます。変分ベイズの手順としては、上で求めたばかりのq(μ)を掛けてμ積分します。しかしSymPyではその計算は重くて実行できないので、ここではμの事前分布p_muを使ってq1_tauを求めてみます。

log_q1_tau = integrate(log_p_for_tau * p_mu, (mu, -oo, oo))
log_q1_tau
log_q1_tau = integrate(diff(log_q1_tau, tau), tau)
log_q1_tau
log_q1_tau = collect(log_q1_tau, tau)
log_q1_tau
  • 3行目: あとで規格化定数を求めればよいので定数項は取り除いておきます。

このすぐあとのq1_tauはガンマ分布になることが分かります。これも共役事前分布を使っているからです。

q1_tau = logcombine(exp(log_q1_tau))
q1_tau
# const = integrate(q1_tau, (tau, 0, oo))
# const
# q1_tau = 1/const * q1_tau
# q1_tau
  • 1行目: logcombine関数を使うことでexp(log(x))xにします。simplify関数だとこの変形をやってくれないことがあります。
  • 3行目: これで素直に積分できればよいのですが、残念ながらできません。

q1_tauは以下です。

 τa0+1eτ(Y122+Y1μ0Y222+Y2μ0Y322+Y3μ0b03μ022)

このexpの肩にのっているτの係数が負だとSymPyが分からないから積分できないようです。ちなみにこのあたりはMathematicaの方が圧倒的に賢くて、例えば以下の入力できちんと積分できます。

Integrate[tau^(a+1)*Exp[tau * (-1/2*x^2 + x*mu - 1/2* y^2 + y*mu - b - mu^2)], {tau, 0, Infinity}, Assumptions -> {b > 0, a > 0, Element[x, Reals], Element[y, Reals], Element[mu, Reals]} ]

これをうまく積分させるには、τの係数が負であることを確認してから変数で置き換えて実行します。

まずτの係数が負であることを確認します。

coef = collect(log_q1_tau, tau).coeff(tau)
coef
sol = solve(diff(coef, Y_vec[0]), Y_vec[0])[0]
sol #=> mu0
replacements = [(var, sol) for var in Y_vec]
coef.subs(replacements) #=> -b0
  • 1行目: τの係数coefを取得しています。
  • 3行目: coefの最大値が負であることを示せばOKです。まずはY_vec[0]についてcoefが最大になる値を探します。それは微分して0(&2階微分が負)になる点を求めればOKです。Y_vec[0]と他のY_vec[*]は区別がある形ではないので、Y_vec[*]についても同じ点でcoefが最大となります。
  • 5~6行目: それをまとめて代入しています。最大値は-b0と分かるので、τの係数は負であることがわかります。

次に変数で置き換えて積分します。

xi = symbols('xi', positive=True)
const = simplify(integrate(tau**(a0+1)*exp(-xi*tau), (tau, 0, oo)))
const = const.subs(xi, -coef)
const
q1_tau = 1/const * q1_tau
q1_tau

constが規格化定数になります。以下の部分です。

 (Y122Y1μ0+Y222Y2μ0+Y322Y3μ0+b0+3μ022)a02Γ(a0+2)

q1_tauは正規化された分布のq(τ)です。以下になります。

 τa0+1Γ(a0+2)(Y122Y1μ0+Y222Y2μ0+Y322Y3μ0+b0+3μ022)a0+2eτ(Y122+Y1μ0Y222+Y2μ0Y322+Y3μ0b03μ022)

このように解析解を求めることはコンセプトの理解に役立ちます。一方で、積分を繰り返して事後分布q(μ,τ)が収束するか確認するようなことは数値的に求めた方が分かりやすいです。

数値的に求める

仮に得られたデータY_vec1.1,1.0,1.3とします。また、事前分布はa0 = 1, b0 = 1, mu0 = 0, lambda0 = 1とします。

replacements = [(a0, 1), (b0, 1), (mu0, 0), (lambda0, 1)]
data_vec = [1.1, 1.0, 1.3]
replacements.extend([(var, val) for var, val in zip(Y_vec, data_vec)])
log_p_for_mu_subs = log_p_for_mu.subs(replacements)
log_p_for_tau_subs = log_p_for_tau.subs(replacements)
[log_p_for_mu_subs, log_p_for_tau_subs]

 [2μ2τ+3.4μτ,τ(2μ2+3.4μ2.95)+2log(τ)]

  • 1行目: 事前分布の分の代入を作っています。
  • 2~3行目: データの分の代入を追加しています。

τの初期分布をp_tauとして、q(μ)を求める→q(τ)を求める→q(μ)を求める→...と7回ほど繰り返してみます。

q_tau = N(p_tau.subs(replacements))
q_tau

for i in range(7):
    log_q_mu = N(integrate(log_p_for_mu_subs * q_tau, (tau, 0, oo)))
    const = N(integrate(exp(log_q_mu), (mu, -oo, oo)))
    q_mu = 1/const * exp(log_q_mu)

    log_q_tau = N(integrate(log_p_for_tau_subs * q_mu, (mu, -oo, oo)))
    const = N(integrate(exp(log_q_tau), (tau, 0, oo)))
    q_tau = 1/const * exp(log_q_tau)

    [q_mu, q_tau]

 [0.188098154753774e2.0μ2+3.4μ,4.03007506250001τ2.0e2.005τ]  [0.112320150163227e2.99251870324189μ2+5.08728179551122μ,3.11052191637731τ2.0e1.83916666666667τ]  [0.0965024138432034e3.26234707748074μ2+5.54599003171727μ,2.97238456804457τ2.0e1.81152777777778τ]  [0.0938011432750369e3.31212144445296μ2+5.63060645557004μ,2.94976700750461τ2.0e1.8069212962963τ]  [0.0933494031271016e3.32056521349236μ2+5.64496086293701μ,2.94600860516469τ2.0e1.80615354938272τ]  [0.0932740721319143e3.32197669575248μ2+5.64736038277921μ,2.94538251532282τ2.0e1.80602559156379τ]  [0.0932615158350226e3.32221205946743μ2+5.64776050109463μ,2.94527817564072τ2.0e1.80600426526063τ]

  • 1行目: N関数は数値による近似を求める関数です。

7回ほどの繰り返しのあとでほぼ収束していそうなことがわかります。

最後に求めた事後分布(の近似)q(μ,τ)=q(μ)q(τ)を可視化してみましょう。SymPyにもsympy.plottingsympy.plotting.plotが存在するのですが、ちょっと凝った図を書こうとするとすぐ厳しくなってしまいます。そこで、得られた事後分布をlambdify関数で関数化し、NumPyとMatplotlibで描くのが拡張性が高くてオススメです。

from sympy.utilities.lambdify import lambdify
import numpy as np
import matplotlib.pyplot as plt

delta = 0.05
x = np.arange(-1.0, 3.0, delta)
y = np.arange(0.0, 6.0, delta)
X, Y = np.meshgrid(x, y)
func = lambdify((mu, tau), q_mu * q_tau, 'numpy')
Z = func(X, Y)

plt.figure()
CS = plt.contour(X, Y, Z)
plt.clabel(CS, inline=1, fontsize=10)

まとめ

  • SymPyはデータサイエンスや機械学習の書籍や論文を読み進める上で、非常に有用な補助ツールです。
  • 現状では細かいところでMathematicaにまだ負けていると思います。プロにはMathematicaがオススメ。オープンソース重視の人やPython好きな人にはSymPyがオススメ。
  • 式変形には「一般的な場合のようにコンセプトが重要で深く理解しなければならない式変形」と「SymPyなどの数式処理ソフトで追えれば十分であるような式変形」があると個人的に思っています。専門書や技術書を執筆する場合は、その二つを区別すると読者にとって親切かなぁと思いました。

Enjoy!

謝辞

北大電子研の佐藤勝彦氏に感謝します。僕が院生の頃に輪読していた ニコリス プリゴジーヌ『散逸構造』の例題をMathematicaで10分ぐらいで一般解を求めるという衝撃のデモを見せてもらい、その後もたまにMathematicaを教えてもらい、数式処理を学ぶきっかけをもらいました。

*1:いつもはStanとの相性を考えてNormal(,)で書いてますが、この記事ではNormal(,)で書きます。

*2:Stanもね。