めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

PRML Figure 5.21 を再現するコード

何の話かというと

A Neural Representation of Sketch Drawings でスケッチの次のストロークを予測するモデルとして、混合ガウス分布が使われており、ガウス分布の混合係数、平均、分散を Latent Variable z を入力とする RNN で計算するという手法が用いられています。

上図のデコーダ部分の出力 y が混合係数、平均、分散にあたります。その後、この分布から次のストロークのサンプルを取得することで、非決定的に画像を生成します。

このモデルは、Bishop先生のMixture Density Networksが元ネタになっており、PRMLにも解説があります。そこで、勉強のためにPRMLで紹介されているサンプルをTensorFlowで実装してみました。

モデルの説明

座標 x に依存して、平均と分散が変化する正規分布 N(tμ(x),σ2(x)) を3つ混合したモデルを考えます。

p(tx)=k=13πk(x)N(tμk(x),σk2(x))
k=13πk(x)=1

ここで、混合係数 πk(x) も座標に依存します。さらに、平均、分散、混合係数の x に対する依存性は、ニューラルネットワークで計算されます。ここでは、一例として、5ノードの隠れ層を1層だけ持つモデルを使用します。

このモデルを用いて、下記のデータセットを学習すると、3つのパートを個別の正規分布でフィッティングできるものと期待されます。

誤差関数には、対数尤度の符号違いを用います。

E(θ)=n=1Nlog{k=13πk(xn,θ)N(tnμk(xn,θ),σk2(xn,θ))}

TensorFlowを用いて実装した結果が下記になります。

実際のコードはこちらです。

Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

Mixture Density Network