Vingow 開発ブログ
やみつきー
2014年4月9日

numpyでのKLダイバージェンスとJensen-Shannonダイバージェンスの実装

KL divergence and JS divergence with numpy

scipyには距離を測るための手続きが用意されています(scipy.spatial.distance)。ユークリッド距離cosine距離(cosine類似度)などもあるのですが、確率分布間の距離とも言うべきKLダイバージェンスやJensen-Shannonダイバージェンスなどは実装されていません。ということで、実装してみました。

実装コードだけ見たいという場合は、最後まで読み飛ばしてください。

KLダイバージェンスとJensen-Shannonダイバージェンスについて

KLダイバージェンス(カルバック・ライブラー情報量; Kullback–Leibler divergence; 相対エントロピー)とは、分布と分布の差異の大きさ(≠距離)を測るものです。分布p と分布q があったとき、p q に対するKLダイバージェンスは

\mathbb {KL}(p||q) = \sum_{k=1}^{K} p_k \log \frac{p_k}{q_k}

で定義されます。また、クロスエントロピーを使って

\mathbb {KL}(p||q) = - \mathbb {H}(p) + \mathbb {H}(p,q)

と定義することもできます。クロスエントロピー\mathbb {H}(p,q) は、真の分布p を、分布q のモデルで表現した際に必要な平均ビット数です。つまりKLダイバージェンスは、真のモデルp で表現したときに比べ、モデルq で表現すると、どれだけ余分にビット数が必要になるか、ということを表しています(p q で表したビット数 - p p で表したビット数)。

話を戻して性質を見てみると、p q が等しいときは

\mathbb {KL}(p||p) = - \mathbb {H}(p) + \mathbb {H}(p,p) = 0

となり、KLダイバージェンスは0です。加えて、\log とかの値域を考えると

\mathbb {KL}(p||q) \geq 0

となるので、KLダイバージェンスは「分布が等しいと0。分布の差異が大きくなるに連れ、KLダイバージェンスも大きくなる」ような値だと見ることができます。

なんだか距離っぽいのですが、\mathbb {KL}(p||q) \neq \mathbb {KL}(q||p) なので、距離とは呼びません(非対称だから)。家からコストコと、コストコから家までの距離は、等しくないといけないのですね。ということで、KLダイバージェンスに対称性を持たせたものがJensen-Shannonダイバージェンスで、次のように定義されます。

\mathbb {JS}(p_1, p_2) = \frac{1}{2} \mathbb {KL}(p_1 || q) + \frac{1}{2} \mathbb {KL}(p_2 || q)  ただし q = \frac{1}{2} p_1 + \frac{1}{2} p_2

確率分布間の距離というと、Webサービスのような実アプリケーションでは使うことがないと考える方がいるかもしれないのですが、自然言語処理の分野では文書を確率分布として表現したり(LDA; latent Dirichlet allocation など)することがあります。vingowでも、これらの確率分布をレコメンデーションなどに応用したりしています(実際はJSダイバージェンスを使わず、単純にベクトルとして類似度を扱ってしまいましたが……)。

numpyでの実装

numpyで実装してみたコードが以下のものになります。

import numpy as np

def kld(p, q):
    """Calculates Kullback–Leibler divergence"""
    p = np.array(p)
    q = np.array(q)
    return np.sum(p * np.log(p / q), axis=(p.ndim - 1))

def jsd(p, q):
    """Calculates Jensen-Shannon Divergence"""
    p = np.array(p)
    q = np.array(q)
    m = 0.5 * (p + q)
    return 0.5 * kld(p, m) + 0.5 * kld(q, m)

ということで、上記の性質を確かめてみましょう。

# あるトピック分布(話題の分布)を持つ文書を想定
doc_A = [0.7, 0.2, 0.1]
# 似たような話題の文書
doc_A2 = [0.8, 0.1, 0.1]
# 全然似てない文書
doc_B = [0.1, 0.8, 0.1]

assert kld(doc_A, doc_A) == 0, "KLD b/w same prob must be 0"
assert kld(doc_A, doc_A2) >= 0, "KLD >= 0"
assert kld(doc_A, doc_A2) < kld(doc_A, doc_B), "More diff prob, larger KLD"
assert kld(doc_A2, doc_A) != kld(doc_A, doc_A2), "Diff order, diff KLD(asymmetric)"

assert jsd(doc_A, doc_A) == 0, "JSD b/w same prob must be 0"
assert jsd(doc_A, doc_A2) >= 0, "JSD >= 0"
assert jsd(doc_A, doc_A2) < jsd(doc_A, doc_B), "More diff prob, larger JSD"
assert jsd(doc_A2, doc_A) == jsd(doc_A, doc_A2), "Diff order, same JSD(symmetric)"

無事、assertが通ると思います(Python 2.7.5 :: Anaconda 1.8.0 (x86_64))

参考文献

  • Machine Learning: a Probabilistic Perspective, Kevin Patrick Murphy, The MIT Press, 2012

ということで、MLaPP(Murphy本)でした!

Vingowを一緒につくりませんか?

弊社ではVingowを一緒に成長させてくれるエンジニア、デザイナー、ディレクターを募集しています。勢いのある若手エンジニアから、若いチームを引っ張ってくれる経験豊富な方まで、Vingowに興味がある方はぜひご連絡下さい。まずは、気軽にランチでも行きましょう!

ランチに行ってみる
«