読者です 読者をやめる 読者になる 読者になる

StatsFragments

Python, R, Rust, 統計, 機械学習とか

Chainer で Deep Learning: Bokeh で Live Monitoring したい

Chainer 可視化 Deep Learning Python

概要

Deep Learning の学習には時間がかかるため、進捗が都度 確認できるとうれしい。その際、テキストのログ出力では味気ないので、リアルタイムでプロットを眺めたい。

いくつかの Deep Learning パッケージではそういった機能 (Live Monitoring) が提供されている。

同じことを Chainer でやりたい。自分は EC2 を使うことが多いので、リモート環境でも利用できるものがいい。そのため、ここでは Bokeh を使うことにした。

Bokeh とは

Bokeh とは、D3.js を利用したブラウザベースのインタラクティブな可視化を実現するパッケージ。どんなものかは公式の Gallery が充実しているのでそちらを。

補足 R 用のパッケージ {rbokeh} もある。

インストール

pip で。

$ pip install bokeh

準備

環境は EC2 上に作成する。Bokeh からは、ファイル、IPythonBokeh 組み込みのWebサーバ ( bokeh-server ) 上の画面 の 3種類に対して出力することができるが、今回は Ipython を使うことにする。

補足 もっとも、常に IPython を使うわけではないため、 bokeh-server 上でも描画できるようにしたい。少し試したが、EC2 上に Bokeh を置くと bokeh-server の画面は開くが個々のプロットが表示できなかったため諦めた (ローカルでは問題ない)。できた方いたらやり方教えてください。

bokeh-server はプロットの描画画面だけでなく、動的なデータ更新のためのAPIも提供している。この例では 学習の進捗をリアルタイムで描画するために bokeh-server の機能を利用する。bokeh-server はシェルから以下のコマンドで起動できる。このとき、外部からの接続を受け入れるには自身のホスト名 / IP を引数 ip として指定する。

$ bokeh-server --ip=ec2-xxx-xxx-xxx-xxx.ap-northeast-1.compute.amazonaws.com

補足 EC2 では (特に設定していなければ) グローバル/プライベートの IP が異なるため、IP 指定では正しくデータが更新されないようだ。そのため、パブリックDNSを指定した。

Live Monitoring

以降は IPython Notebook から。Live Monitoring のためのプロットを行うクラスを定義する。

import numpy as np
import chainer

import bokeh.plotting as plotting
from bokeh.models import GlyphRenderer


class LiveMonitor(object):

    def __init__(self, server='chainer', url='http://localhost:5006/', **kwargs):
        # 出力先に IPython Notebook を指定
        plotting.output_notebook(url=url)

        # トレーニング/テストデータの loss, accuracy を描画する figure を定義
        self.train_loss = self._initialize_figure(title='Train loss', color='#FF0000', **kwargs)
        self.train_acc = self._initialize_figure(title='Train accuracy', color='#0000FF', **kwargs)
        self.test_loss = self._initialize_figure(title='Test loss',  color='#FF0000', **kwargs)
        self.test_acc = self._initialize_figure(title='Test accuracy', color='#0000FF', **kwargs)

        # figure を 2 x 2 のグリッド (サブプロット) として配置
        self.grid = plotting.gridplot([[self.train_loss, self.test_loss],
                                       [self.train_acc, self.test_acc]])

        # グリッドを描画
        plotting.show(self.grid)

    def _initialize_figure(self, color=None, line_width=2,
                           title=None, title_text_font_size='9pt',
                           plot_width=380, plot_height=280):
        """ figure の初期化用のメソッド"""

        figure = plotting.figure(title=title, title_text_font_size=title_text_font_size,
                                 plot_width=plot_width, plot_height=plot_height)

        # 空のデータで折れ線グラフを作成
        x = np.array([])
        y = np.array([])
        figure.line(x, y, color=color, line_width=line_width)
        return figure

    def update(self, train_loss=None, train_accuracy=None,
               test_loss=None, test_accuracy=None):
        """
        プロットを更新するためのメソッド
        指定したキーワード引数に対応する figure が更新される
        """
        self._maybe_update(self.train_loss, train_loss)
        self._maybe_update(self.train_acc, train_accuracy)
        self._maybe_update(self.test_loss, test_loss)
        self._maybe_update(self.test_acc, test_accuracy)

    def _maybe_update(self, figure, value):
        """ figure の値を更新するメソッド"""

        if value is not None:

            # Variable から np.array に戻す
            if isinstance(value, chainer.Variable):
                value = chainer.cuda.to_cpu(value.data)

            # figure が利用している data_source を取得
            renderer = figure.select(dict(type=GlyphRenderer))
            ds = renderer[0].data_source

            # data_source 中の値を更新
            y = np.append(ds.data['y'], value)
            ds.data['y'] = y
            ds.data['x'] = np.arange(len(y))

            # session へ返す (とプロットが更新される)
            plotting.cursession().store_objects(ds)

このクラスをインスタンスにする。URL としては、bokeh-server の起動時に指定したものを 以下 URL の形式で指定する。

monitor = LiveMonitor(url="http://ec2-xxx-xxx-xxx-xxx.ap-northeast-1.compute.amazonaws.com:5006/")

あとは 学習 / テスト中に、monitor.update に適当な引数を渡せばよい。chainer/examples/mnist/train_mnist.py を例にすると、

# ... 略
        optimizer.zero_grads()
        loss, acc = forward(x_batch, y_batch)
        loss.backward()
        optimizer.update()

        monitor.update(train_loss=loss, train_accuracy=acc)  

# ... 略

        loss, acc = forward(x_batch, y_batch, train=False)
        monitor.update(test_loss=loss, test_accuracy=acc)

IPython 上 ( LiveMonitor インスタンスを作成した直後) のセルに、以下のようなプロットが表示される。プロットは update が呼ばれるたびに更新される。

f:id:sinhrks:20150709230044p:plain

また、各プロットはインタラクティブにズーム/パンといった操作ができる (左上)。

f:id:sinhrks:20150709230053p:plain

まとめ

これでリアルタイムに進捗を眺めて楽しむことができる。

深層学習 (機械学習プロフェッショナルシリーズ)

深層学習 (機械学習プロフェッショナルシリーズ)