Gradient Boosting Decision Tree の C++ 実装 + 各言語のバインディングである XGBoost
、かなり強いらしいという話は伺っていたのだが自分で使ったことはなかった。こちらの記事で Python 版の使い方が記載されていたので試してみた。
その際、XGBoost
Python 側でのプロット / 可視化の実装がなかったためプルリクを出した。無事 マージ & リリースされたのでその使い方を書きたい。まずはデータを準備し学習を行う。
import numpy as np import xgboost as xgb from sklearn import datasets import matplotlib.pyplot as plt plt.style.use('ggplot') xgb.__version__ # '0.4' iris = datasets.load_iris() dm = xgb.DMatrix(iris.data, label=iris.target) np.random.seed(1) params={'objective': 'multi:softprob', 'eval_metric': 'mlogloss', 'eta': 0.3, 'num_class': 3} bst = xgb.train(params, dm, num_boost_round=18)
1. 変数重要度のプロット
Python 側には R のように importance matrix を返す関数がない。GitHub 上でも F score を見ろ、という回答がされていたので F score をそのままプロットするようにした。
xgb.plot_importance(bst)
棒グラフの色、タイトル/軸のラベルは以下のように変更できる。
xgb.plot_importance(bst, color='red', title='title', xlabel='x', ylabel='y')
color
にリストを渡せば棒ごとに色が変わる。また、ラベルを消したい場合は None
を渡す。
xgb.plot_importance(bst, color=['r', 'r', 'b', 'b'], title=None, xlabel=None, ylabel=None)
xgboost
は内部的に変数名を保持していないため、変数名でプロットしたい場合は 一度 F score を取得して変数名に差し替えてからプロットする。
bst.get_fscore() # {'f0': 17, 'f1': 16, 'f2': 95, 'f3': 59} iris.feature_names # ['sepal length (cm)', # 'sepal width (cm)', # 'petal length (cm)', # 'petal width (cm)'] mapper = {'f{0}'.format(i): v for i, v in enumerate(iris.feature_names)} mapped = {mapper[k]: v for k, v in bst.get_fscore().items()} mapped # {'petal length (cm)': 95, # 'petal width (cm)': 59, # 'sepal length (cm)': 17, # 'sepal width (cm)': 16} xgb.plot_importance(mapped)
2. 決定木のプロット
以下二つの関数を追加した。graphviz
が必要なためインストールしておくこと。
to_graphviz
: 任意の決定木をgraphviz
インスタンスに変換する。IPython
上であればそのまま描画できる。plot_tree
:to_graphviz
で取得したgraphviz
インスタンスをmatplotlib
のAxes
上に描画する。
IPython
から実行する。num_trees
で指定した番号に対応する木が描画される。
xgb.to_graphviz(bst, num_trees=1)
エッジの色分けが不要なら明示的に黒を指定する。
xgb.to_graphviz(bst, num_trees=2, yes_color='#000000', no_color='#000000')
IPython
を使っていない場合や、サブプロットにしたい場合には plot_tree
を利用する。
_, axes = plt.subplots(1, 2) xgb.plot_tree(bst, num_trees=2, ax=axes[0]) xgb.plot_tree(bst, num_trees=3, ax=axes[1])
何かおかしいことをやっていたら 本体の方で issue お願いします。