StatsFragments

Python, R, Rust, 統計, 機械学習とか

Python XGBoost の変数重要度プロット / 可視化の実装

Gradient Boosting Decision Tree の C++ 実装 & 各言語のバインディングである XGBoost、かなり強いらしいという話は伺っていたのだが自分で使ったことはなかった。こちらの記事で Python 版の使い方が記載されていたので試してみた。

puyokw.hatenablog.com

その際、Python でのプロット / 可視化の実装がなかったためプルリクを出した。無事 マージ & リリースされたのでその使い方を書きたい。まずはデータを準備し学習を行う。

import numpy as np
import xgboost as xgb
from sklearn import datasets

import matplotlib.pyplot as plt
plt.style.use('ggplot')

xgb.__version__
# '0.4'

iris = datasets.load_iris()
dm = xgb.DMatrix(iris.data, label=iris.target)

np.random.seed(1) 

params={'objective': 'multi:softprob',
        'eval_metric': 'mlogloss',
        'eta': 0.3,
        'num_class': 3}

bst = xgb.train(params, dm, num_boost_round=18)

1. 変数重要度のプロット

Python 側には R のように importance matrix を返す関数がない。GitHub 上でも F score を見ろ、という回答がされていたので F score をそのままプロットするようにした。

xgb.plot_importance(bst)

f:id:sinhrks:20150826235007p:plain

棒グラフの色、タイトル/軸のラベルは以下のように変更できる。

xgb.plot_importance(bst, color='red', title='title', xlabel='x', ylabel='y')

f:id:sinhrks:20150826235022p:plain

color にリストを渡せば棒ごとに色が変わる。色の順序は matplotlibbarh と同じく下からになる。また、ラベルを消したい場合は None を渡す。

xgb.plot_importance(bst, color=['r', 'r', 'b', 'b'], title=None, xlabel=None, ylabel=None)

f:id:sinhrks:20150826235030p:plain

XGBoost は内部的に変数名を保持していない。変数名でプロットしたい場合は 一度 F score を含む辞書を取得して、キーを差し替えてからプロットする。

bst.get_fscore()
# {'f0': 17, 'f1': 16, 'f2': 95, 'f3': 59}

iris.feature_names
# ['sepal length (cm)',
#  'sepal width (cm)',
#  'petal length (cm)',
#  'petal width (cm)']

mapper = {'f{0}'.format(i): v for i, v in enumerate(iris.feature_names)}
mapped = {mapper[k]: v for k, v in bst.get_fscore().items()}
mapped
# {'petal length (cm)': 95,
#  'petal width (cm)': 59,
#  'sepal length (cm)': 17,
#  'sepal width (cm)': 16}

xgb.plot_importance(mapped)

f:id:sinhrks:20150826235041p:plain

2. 決定木のプロット

以下二つの関数を追加した。graphviz が必要なためインストールしておくこと。

  • to_graphviz: 任意の決定木を graphviz インスタンスに変換する。IPython 上であればそのまま描画できる。
  • plot_tree: to_graphviz で取得した graphviz インスタンスmatplotlibAxes 上に描画する。

IPython から実行する。num_trees で指定した番号に対応する木が描画される。

xgb.to_graphviz(bst, num_trees=1)

f:id:sinhrks:20150826235056p:plain

エッジの色分けが不要なら明示的に黒を指定する。

xgb.to_graphviz(bst, num_trees=2, yes_color='#000000', no_color='#000000')

f:id:sinhrks:20150826235105p:plain

IPython を使っていない場合や、サブプロットにしたい場合には plot_tree を利用する。

_, axes = plt.subplots(1, 2)
xgb.plot_tree(bst, num_trees=2, ax=axes[0])
xgb.plot_tree(bst, num_trees=3, ax=axes[1])

f:id:sinhrks:20150826235251p:plain

何かおかしいことをやっていたら 本体の方で issue お願いします。

10/3追記 その後の修正を以下にしました。変数名の指定などが簡単になっています。

sinhrks.hatenablog.com