読者です 読者をやめる 読者になる 読者になる

機械学習モデルの予測結果を説明するための力が欲しいか...?

はじめに

最近はAIや機械学習などの単語がビジネスで流行っていて、世はAI時代を迎えている。QiitaやTwitterを眺めているとその影響を受けて、世の多くのエンジニアがAIの勉強を始め出しているように見受けられる。

さらに、近年では機械学習のライブラリも充実しており、誰でも機械学習を実装することができる良い時代になってきた。

その一方で、特徴選択を行い精度を向上させたり、機械学習の出した答えがどの特徴に基づいて判断されたのかを理解したりするには、モデルに対する理解やテクニックが必要となる場合も多々ある。複雑なモデルになると人間には解釈が困難で説明が難しい。近頃流行りのDeep Learning系のモデルだと頻繁に「なんかよくわからないけどうまくいきました」となっていると思う。

一般的なエンジニアとしては、この点が割と課題なんじゃないかと勝手に思っている。というか、私が課題に感じている。(特に実業務で機械学習していない上に、エンジニアでもないが)

そんなわけで、今回はこの課題を解決するためのツールであるLIME(Local Interpretable Model-agnostic Explainations)が興味深かったので、紹介していこうかと思う。
※本記事はLIMEのアルゴリズムの説明となるため、LIMEを実際に利用したい方はGitHub - marcotcr/lime: Lime: Explaining the predictions of any machine learning classifierpythonのライブラリインストール方法とチュートリアルが載っているので、そちらをご参照ください。

モデルの説明とは何か

LIMEの紹介に移る前に機械学習モデルを説明するとはどういうことなのか整理していきたい。
機械学習モデルの説明には下記の説明の2種類が考えられる。

  • explaining prediction(予測の説明)

データ一つに対する機械学習モデルの分類器による予測結果に対して、どうして分類が行われたのかを説明すること。(下記の図はイメージ)
f:id:gat-chin321:20170107164420p:plain
(出典: https://arxiv.org/pdf/1602.04938.pdf)

  • explaining models(モデルの説明)

分類器がどういう性質を持っているのかを説明すること。
https://d3ansictanv2wj.cloudfront.net/figure2-802e0856e423b6bf8862843102243a8b.jpg
(出典: Introduction to Local Interpretable Model-Agnostic Explanations (LIME) - O'Reilly Media)

LIMEはこのうち、explaining predictionを行うためのアルゴリズムである。
explaining modelsについては、SP-LIMEと呼ばれるアルゴリズムが論文に記載されているので、そちらを参照されたし。(気が向けば、SP-LIMEについても記事を書く)

LIME(Local Interpretable Model-agnostic Explainations)の紹介

LIMEとは?

KDD2016で採択された『“Why Should I Trust You?” Explaining the Predictions of Any Classifier』というタイトルの論文で発表されたアルゴリズム。分類器がどのように判断してラベリングを行なったのかを人間でも解釈できるような形で提示してくれる。
このアルゴリズムはあるデータを分類した結果、それぞれの特徴がどの程度分類に貢献しているかを調べることで分類器の予測結果を説明している。また、分類器の予測結果を用いるため、任意の分類器に適用できる特徴がある。

LIMEのアイデア

データxの周辺からサンプリングしたデータを用いて、説明したい分類器の出力と近似するように解釈可能な(かつ単純な)モデルを学習させる。その後、得られた分類器を用いて分類結果の解釈を行う。下記がイメージ図(論文から抜粋した図を編集)。

f:id:gat-chin321:20170106191925p:plain
(出典: https://arxiv.org/pdf/1602.04938.pdf)

説明用分類器の学習方法

説明用分類器 gはデータxの周辺でfの結果と近似するようにしたい。
そうするために、下記の目的関数を利用して学習する。

{\displaystyle 
\DeclareMathOperator*{\argmin}{arg\,min}
\begin{equation}
\xi(x) = \argmin_{g \in G} L(f, g, \pi_x) + \Omega(g)
\end{equation}
}

  • G : 解釈可能なモデルの集合
  • g : Gのうちの一つのモデル。例えば、線形モデルなど
  • f : 説明したい分類器
  • \pi_x : データxとの距離
  • {\displaystyle L(f, g, \pi_x)} : データxの周辺でfgの結果がどれだけ違っているか(Lは損失関数ともいう)
  • {\displaystyle \Omega(g)} : 説明用分類器gの複雑さ

上記の内容から、\xi(x)はデータxの周辺でfgの結果についての食い違い{\displaystyle L(f, g, \pi_x)}と説明用分類器gの複雑さの和を最小にする g の集合を求めるものであると言える。
ここで、{\displaystyle \Omega(g)}はテキスト分類の場合、解釈可能なモデルの特徴表現を単語の有無{0,1}のBag-of-Words法(単語袋詰め)とし、単語の数(次元数)に限度Kを設定することで、説明が解釈可能であることを保証するためのものらしい。
画像データの場合はsuper-pixelsと呼ばれる任意のアルゴリズムを使用して計算されるものを用いて解釈可能なモデルの特徴表現とする。
ここで、この特徴表現は{0,1}の2値で表され、1は元のsuper-pixels、0はグレーアウトされたsuper-pixelsを示す。
ここまでで\xi(x)について、何となくというレベルでは理解ができたと思いたい。
そこで、次は{\displaystyle L(f, g, \pi_x)} の数式についても見ていこう。

{\displaystyle 
L(f, g, \pi_x) = \sum_{z,z' \in Z } \pi_x (z) (f(z) - g(z'))^2
}

  •  Z :  xの周辺のデータの集合
  •  z' : 非ゼロ要素を一部だけ含むサンプリングにより生成された2値のスパースな点。

  z' \in \{0,1\}^dで定義される

  •  z :  z'を用いて復元された元のサンプルの特徴表現。 z \in R^dで定義される

この式を見る限り、 xの周辺のデータにおける\pi_x (z)で重み付けした残差平方和を出している。
残差平方和自体は正解データ(今回の場合、説明したい分類モデルの予測結果)と推定モデルの予測結果との間の不一致を評価する尺度なので、わかりやすいかと思う。
また、\pi_x (z)で重み付けしている理由について理解するため、\pi_x (z)の式を見ていこう。

{\displaystyle 
\pi_x (z) = exp\Bigl(\frac{-D(x,z)^2}{\sigma^2}\Bigr)
}

  •  D(x,z) :  x zとの距離関数(例えば、テキストならコサイン類似度、画像ならL2ノルムなどを利用する)
  •  \sigma : 指数カーネルカーネル

\pi_x (z)の式はカーネル関数であり、xとzの2変数間の類似度を算出している。\pi_x (z)はテキトーに0から1までの値を入れて見て計算すればわかると思うが、サンプルが近ければ近いほど値が小さくなる。これで重み付けすることで、 x zとの距離が近いサンプルの場合は損失{\displaystyle L(f, g, \pi_x)}が小さくなりやすくなり、逆に距離が遠いサンプルの場合は損失が高くなる。この重み付けのおかげで、ロバストなモデルとなっている。
最後は\Omega(g)について掘り下げていく。\Omega(g)の式を見ていこう。

{\displaystyle 
\Omega(g) = \infty \mathbb{1} [||w_g||_0 > K ]
}

\Omega(g)は利用する特徴がたかだか単語数(もしくはsuper-pixels)K程度だけとすることを示しているっぽい。
利用する特徴\Omega(g)の選択は、方程式\xi(x)から直接解くことで実現することは難しい。
そのため、まず著者らがK-Lassoと呼んでいる、Lassoで正則化パスを使用して利用する特徴をK個選択し、最小二乗法を介して重みを学習する方法によって、利用する特徴\Omega(g)の選択についての解と近似させる。
これにより、方程式\xi(x)を解くことができるようになるため、線形モデル(Githubのコードを読む限りではRidge回帰)で学習を行う。
この学習した線形モデルの偏回帰係数を確認することで、選択された特徴について、どれだけ分類に貢献しているかの説明を行うことができる。

ここまで、説明した内容が下記の図のAlgorithm 1 である。
f:id:gat-chin321:20170107152919p:plain
(出典: https://arxiv.org/pdf/1602.04938.pdf)

Algorithm 1 は個々の予測についての説明を生成するので、その複雑さはデータセットのサイズに依存するのではなく、 f(x)を計算する時間とサンプル数 Nに依存するらしい。

著者らのLIMEパッケージを使うと、下記の図のように直感的な説明が表示されて使いやすそう。
f:id:gat-chin321:20170107170455p:plain
f:id:gat-chin321:20170107170504p:plain
f:id:gat-chin321:20170107170221p:plain

(出典: Lime - basic usage, two class case)

何かしらいい感じのデータが手に入ったら是非とも使うか、そのうちKaggleか何かやろうかな(やらないやつ)。

参考文献

LIME論文:

“Why Should I Trust You?” Explaining the Predictions of Any Classifier
https://arxiv.org/pdf/1602.04938.pdf

LIMEコード:

github.com