numpyでのKLダイバージェンスとJensen-Shannonダイバージェンスの実装
scipyには距離を測るための手続きが用意されています(scipy.spatial.distance)。ユークリッド距離やcosine距離(cosine類似度)などもあるのですが、確率分布間の距離とも言うべきKLダイバージェンスやJensen-Shannonダイバージェンスなどは実装されていません。ということで、実装してみました。
実装コードだけ見たいという場合は、最後まで読み飛ばしてください。
KLダイバージェンスとJensen-Shannonダイバージェンスについて
KLダイバージェンス(カルバック・ライブラー情報量; Kullback–Leibler divergence; 相対エントロピー)とは、分布と分布の差異の大きさ(≠距離)を測るものです。分布と分布
があったとき、
の
に対するKLダイバージェンスは
で定義されます。また、クロスエントロピーを使って
と定義することもできます。クロスエントロピーは、真の分布
を、分布
のモデルで表現した際に必要な平均ビット数です。つまりKLダイバージェンスは、真のモデル
で表現したときに比べ、モデル
で表現すると、どれだけ余分にビット数が必要になるか、ということを表しています(
を
で表したビット数 -
を
で表したビット数)。
話を戻して性質を見てみると、と
が等しいときは
となり、KLダイバージェンスは0です。加えて、とかの値域を考えると
となるので、KLダイバージェンスは「分布が等しいと0。分布の差異が大きくなるに連れ、KLダイバージェンスも大きくなる」ような値だと見ることができます。
なんだか距離っぽいのですが、なので、距離とは呼びません(非対称だから)。家からコストコと、コストコから家までの距離は、等しくないといけないのですね。ということで、KLダイバージェンスに対称性を持たせたものがJensen-Shannonダイバージェンスで、次のように定義されます。
ただし
確率分布間の距離というと、Webサービスのような実アプリケーションでは使うことがないと考える方がいるかもしれないのですが、自然言語処理の分野では文書を確率分布として表現したり(LDA; latent Dirichlet allocation など)することがあります。vingowでも、これらの確率分布をレコメンデーションなどに応用したりしています(実際はJSダイバージェンスを使わず、単純にベクトルとして類似度を扱ってしまいましたが……)。
numpyでの実装
numpyで実装してみたコードが以下のものになります。
import numpy as np def kld(p, q): """Calculates Kullback–Leibler divergence""" p = np.array(p) q = np.array(q) return np.sum(p * np.log(p / q), axis=(p.ndim - 1)) def jsd(p, q): """Calculates Jensen-Shannon Divergence""" p = np.array(p) q = np.array(q) m = 0.5 * (p + q) return 0.5 * kld(p, m) + 0.5 * kld(q, m)
ということで、上記の性質を確かめてみましょう。
# あるトピック分布(話題の分布)を持つ文書を想定 doc_A = [0.7, 0.2, 0.1] # 似たような話題の文書 doc_A2 = [0.8, 0.1, 0.1] # 全然似てない文書 doc_B = [0.1, 0.8, 0.1] assert kld(doc_A, doc_A) == 0, "KLD b/w same prob must be 0" assert kld(doc_A, doc_A2) >= 0, "KLD >= 0" assert kld(doc_A, doc_A2) < kld(doc_A, doc_B), "More diff prob, larger KLD" assert kld(doc_A2, doc_A) != kld(doc_A, doc_A2), "Diff order, diff KLD(asymmetric)" assert jsd(doc_A, doc_A) == 0, "JSD b/w same prob must be 0" assert jsd(doc_A, doc_A2) >= 0, "JSD >= 0" assert jsd(doc_A, doc_A2) < jsd(doc_A, doc_B), "More diff prob, larger JSD" assert jsd(doc_A2, doc_A) == jsd(doc_A, doc_A2), "Diff order, same JSD(symmetric)"
無事、assertが通ると思います(Python 2.7.5 :: Anaconda 1.8.0 (x86_64))
参考文献
- Machine Learning: a Probabilistic Perspective, Kevin Patrick Murphy, The MIT Press, 2012
ということで、MLaPP(Murphy本)でした!