唯物是真 @Scaled_Wurm

プログラミング(主にPython2.7)とか機械学習とか

サザエさんのジャンケンの次の手を決定木で予測+可視化してみた

前に決定木の可視化をしようと思ってやってなかったのでやっておきます

決定木のライブラリは例のごとくscikit-learnを使う

pythonの機械学習ライブラリscikit-learnの紹介 - 唯物是真 @Scaled_Wurm

決定木とは

決定木は教師あり学習で使われるモデルで、ルールを木として学習します
例えば身長、体重から性別を予測したい場合、身長が170cm以上で体重60kg以上なら男、みたいなルールを学習します
性能はあまりよくないモデルですが、人間にもわかりやすいルールを出力する(他のモデルと比べれば)という特徴があります

簡単に説明すると、ある変数が一定値以上であるかという条件で分けた時に、データのラベル(性別なら男女)ごとの分布がどちらかに偏るような条件で木を作っていきます
予測するときには、データが条件を満たしているノードをたどって木の一番下の葉ノードまでいって、葉ノードに割り当てられた学習データが一番多いラベルに分類されます

可視化すると以下のようになります
これはこの記事の下の方でやるサザエさんのジャンケン予測の決定木の簡単な例です
valueはその葉ノードに割り当てられたグー、チョキ、パーそれぞれのデータの数、giniはラベルの分布の偏りの指標です
f:id:sucrose:20141121235829p:plain
条件が真の時に1と定義しているので、1手前がチョキ以外なら左側の葉ノードにたどりつくので次はチョキ、1手前がチョキだったら右側の葉ノードにたどりつくので次はパーという予測ルールを表していることになります

サザエさんのジャンケン予測

適当なデータで試してみようということでサザエさんのジャンケンの次の手を予測します

先行研究によれば正解率50%以上で予測できることが知られています

サザエさんのじゃんけん予測問題のサーベイ - 唯物是真 @Scaled_Wurm

データは過去にクロールしたのがそのままあったので使います↓
Janken_Classification/sazae.tri at master · mugenen/Janken_Classification · GitHub

サザエさん(とプリキュア)のジャンケンデータのダウンロード - 唯物是真 @Scaled_Wurm

過去の3回までの手から次の手を予測します
20分割のクロスバリデーション(データの19/20で学習して残りの1/20で評価するのを20回繰り返す)で正解率を評価して(マクロ)平均を出力します

scikit-learnの決定木のライブラリでは木の深さ(≒ルールの複雑さ)の最大値を指定できるので、変えて実験しています

ある程度の深さまでは正解率が上がりますが、増やし過ぎると正解率が下がります

木の最大の深さ 正解率の平均
1 42.6%
2 48.4%
3 49.9%
4 51.2%
5 51.5%
6 51.3%

正解率を出すのに使ったコード

# -*- coding: utf-8 -*-
import sklearn.tree
import sklearn.datasets
import sklearn.cross_validation

#データの読み込み
X, y = sklearn.datasets.load_svmlight_file('sazae.tri.txt')

#深さを変えて実験
for d in xrange(1, 7):
    clf = sklearn.tree.DecisionTreeClassifier(max_depth=d)
    result = sklearn.cross_validation.cross_val_score(clf, X.toarray(), y, cv=20)
    print u'最大深さ: {}, 正解率の平均: {:.1%}'.format(d, result.mean())

決定木の可視化

ようやく本題ですが、scikit-learnには可視化用の関数があるので簡単にできます
sklearn.tree.export_graphviz関数に決定木を与えれば、グラフ可視化用のソフトのgraphvizの形式で出力してくれます

以下のスクリプトを実行してできたファイルに対して、"dot -Tpng 入力ファイル名 -o 出力ファイル名"とコマンドを実行するとpng形式で得られます(Graphvizのインストールが必要)
フォントを指定しないと日本語が表示できないので注意

# -*- coding: utf-8 -*-
import sklearn.tree
import sklearn.datasets
import StringIO
import contextlib

X, y = sklearn.datasets.load_svmlight_file('sazae.tri.txt')

clf = sklearn.tree.DecisionTreeClassifier(max_depth=2)

clf.fit(X.toarray(), y)

with contextlib.closing(StringIO.StringIO()) as temp:
    sklearn.tree.export_graphviz(clf, out_file=temp, feature_names='グー(3手前) チョキ(3手前) パー(3手前) グー(2手前) チョキ(2手前) パー(2手前) グー(1手前) チョキ(1手前) パー(1手前)'.split())
    output = temp.getvalue().splitlines()

#日本語を表示するときはフォントを指定しないといけない
output.insert(1, 'node[fontname="meiryo"];')
with open('tree.dot', 'w') as f:
    f.write('\n'.join(output))

どんなルールを学習できたか

決定木の可視化は以下のような出力になります
f:id:sucrose:20141121230323p:plain
慣れないとわかりづらいと思ったので左上に赤字で予測結果、下にその葉ノードに辿り着く過去のジャンケンの手を書いておきました

深さ1までの場合。こんな単純なルールでも4割以上当たります
f:id:sucrose:20141121235829p:plain
深さ2まで。過去に出た手以外を選ぶようなルールが学習されています
f:id:sucrose:20141122191609p:plain
深さ3以上は木が大きくてわかりづらいので省略しました

-->