読者です 読者をやめる 読者になる 読者になる

StatsFragments

Python, R, Rust, 統計, 機械学習とか

Python Theano function / scan の挙動まとめ

Python Theano

勉強のため たまに Pylearn2 など Theano を使ったパッケージのソースを眺めたりするのだが、theano.scan の挙動を毎回 忘れてしまう。繰り返し調べるのも無駄なので、一回 整理したい。theano.scan の動作は theano.function が前提となるため、あわせて書く。

準備

import numpy as np
import theano
import theano.tensor as T

theano.function

まずは Theano における関数にあたる Function インスタンスを作成する theano.function の基本的な挙動について。引数はいろいろあるが、特に重要と思われるのは以下の4つ。

  • inputs : Function への入力 (引数) に対応するシンボル。
  • outputs : Function 化される式。
  • updates : 共有変数を更新する式。
  • givens : 引数 ( inputs ) -> 何かへのマッピングを行う辞書。

引数 ひとつを受け取って 1 インクリメントした値を返す Function は以下のように作る。f1(3) の 3 が 定義中のシンボル a に対応する。inputs には、引数がひとつ もしくは ない場合でもリストを指定すること。

a = T.lscalar('a')
f1 = theano.function(inputs=[a], outputs=a + 1)
f1
# <theano.compile.function_module.Function at 0x10d00fb90>

f1(3)
# array(4)

補足 theano は内部的に 式をグラフ構造として処理している。その構造は debugprint などで表示できる。

theano.printing.debugprint(f1)
# Elemwise{add,no_inplace} [@A] ''   0
#  |TensorConstant{1} [@B]
#  |a [@C]

補足 また、グラフ構造はダイアグラムとしても描画できる。IPython (Jupyter) 上で描画する場合は、

import IPython.display as display

def plot(obj):
    svg = theano.printing.pydotprint(obj, return_image=True, format='svg')
    return display.SVG(svg)

plot(f1)

f:id:sinhrks:20150425222011p:plain

引数をふたつ受け取りそれらの和を返す Function は以下のようになる。

b = T.lscalar('b')
f2 = theano.function(inputs=[a, b], outputs=(a + b))
f2(3, 4)
# array(7)

また、各引数の既定値 や 別名は theano.Param を通すことで指定できる。以下の例では、ふたつめの引数 b が指定されない場合、既定値として 1 を使う。

f3 = theano.function(inputs=[a, theano.Param(b, default=1)], outputs=(a + b))
f3(3)
# array(4)

f3(3, 2)
# array(5)

引数をふたつ受け取り、それらの和と差を返す Functionoutputs複数の式をリストで渡せばよい。

f4 = theano.function(inputs=[a, b], outputs=[a + b, a - b])
f4(3, 4)
# [array(7), array(-1)]

引数をひとつ受け取り、別に定義した 共有変数との和を返す Function。共有変数は function 内で直接使ってよい。

s = theano.shared(2)
s.get_value()
# array(2)

f5 = theano.function(inputs=[a], outputs=a + s)
f5(3)
# array(5)

引数 / 返り値はなしで、別に定義した 共有変数を破壊的に 1 増やす Function。共有変数に対する更新は updates で定義。

inc = theano.function(inputs=[], outputs=None, updates={s: s + 1})
s.get_value()
# array(2)

inc()
# []

s.get_value()
# array(3)

引数をひとつ受け取り、共有変数との和を返す Function。同時に、共有変数を破壊的に 1 増やす。共有変数に対する更新は updates で定義。updates の処理は outputs の後に行われているようだ。

s.set_value(0)
f6 = theano.function(inputs=[a], outputs=a + s, updates={s: s + 1})
f6(2)
# array(2)

s.get_value()
# array(1)

f6(2)
# array(3)

また、引数はマッピングさせて使うこともできる。以下のように givens を定義すると、式中のシンボル s が 引数 xマッピングされる。

x = T.lscalar('x')
f7 = theano.function(inputs=[a, x], outputs=a + s, givens=[(s, x)])
f7(2, 5)
# array(7)

マッピングは定数でもよい。

f8 = theano.function(inputs=[a], outputs=a + s, givens=[(s, 3)])
f8(2)
# array(5)

引数が vector ならブロードキャストされる。

v = T.ivector()
f9 = theano.function(inputs=[v], outputs=v * 2)
f9([1, 1, 1]) 
# array([2, 2, 2], dtype=int32)

theano.scan

次に、Theano における 繰り返し処理に対応する theano.scanscan にはおおきく 以下 2 種類の動きがあり、混同するとわけわからなくなる。それぞれ明確な名前が付いているわけではなさそうだが、便宜上 区別したいので 以下ドキュメントの章題をもとに それぞれ Loop / Iteration と書く。

参考 scan – Looping in Theano — Theano 0.7 documentation

  1. Loop: ある関数 fn を、引数に対して n_steps 回 適用する。返り値は 長さ n_steps のベクトルとなり、 [fn(x), fn(fn(x))...] のような処理になる。

  2. Iteration: ある関数 fn を、シーケンス的な引数に対して適用する。返り値は 引数の各要素と同じ長さのベクトルとなり、 [fn(x) for x in ...] のような処理になる。

scan がどちらの処理を行うかは 引数をどのように渡すかによって決まる。

  • fn : Loop / Iteration において適用される式。
  • sequences : シーケンスとして Iteration 処理される引数。繰り返しのたびにシーケンスの 1, 2, 3... 番目の要素が順に fn へ渡る。
  • outputs_info : Loop 処理の初期値となる値。
  • non_sequences : シーケンスでない fn への引数。繰り返しのたびに同じ値が fn へ渡る。
  • n_steps : 繰り返し処理を行う回数。

このとき、scanfn に渡す引数は (最大で) 以下の 3 つになる。それぞれ、対応する引数がない場合は省略される (fn に渡される引数の数自体が変わる)。

  1. シーケンスの要素 (sequences が指定されている場合)
  2. 直前の繰り返し処理の結果 (outputs_info が指定されている場合)
  3. シーケンスでない引数 = non_sequencesそのもの (non_sequencesが指定されている場合)

そのため、scan 処理を書く場合の考え方は以下のようになると思う。

  1. 処理に対して適切な 引数 sequences, outputs_info, non_sequences が決まる
  2. fn に渡される引数が決まる -> fn の具体的な処理が決まる

まずは 単純な Loop 処理。ひとつ目の引数 5 を 2 倍する処理を 3 回繰り返す。結果は [5*2, 5*2*2, 5*2*2*2] のベクトルとなる。

n = T.iscalar('n')
result, updates = theano.scan(fn=lambda prior, nonseq: prior * 2,
                              sequences=None,
                              outputs_info=a,
                              non_sequences=a,
                              n_steps=n)

sf1 = theano.function(inputs=[a, n], outputs=result, updates=updates)
sf1(5, 3)
# array([10, 20, 40])

Loop 処理の最後の結果だけが欲しい場合は、theano.functionoutputs でベクトル末尾の要素を指定する。

sf2 = theano.function(inputs=[a, n], outputs=result[-1], updates=updates)
sf2(5, 3)
# array(40, dtype=int32)

ベクトルに対する Loop 処理。シンボルの型と引数が変わるだけ。

v = T.ivector('v')
result, updates = theano.scan(fn=lambda prior, nonseq: prior * 2,
                              sequences=None,
                              outputs_info=v,
                              non_sequences=v,
                              n_steps=n)
sf3 = theano.function(inputs=[v, n], outputs=result[-1], updates=updates)
sf3([1, 2, 3], 3)
# array([ 8, 16, 24], dtype=int32)

fn 内で non_sequences のみを利用したときの結果をみる。non_sequencesには 繰り返し回数にかかわらず同じ値が渡ってくるため、ループ回数 n_steps は結果に影響しなくなる。

result, updates = theano.scan(fn=lambda prior, nonseq: nonseq * 2,
                              sequences=None,
                              outputs_info=v,
                              non_sequences=v,
                              n_steps=n)
sf4 = theano.function(inputs=[v, n], outputs=result[-1], updates=updates)
sf4([1, 2, 3], 100)
# array([2, 4, 6], dtype=int32)

続けて Iteration 処理。以下は 渡された引数 5 と同じ長さのベクトル [0, 1, 2, 3, 4] をつくり、各要素に対して 2 を加算する。

outputs = T.as_tensor_variable(np.asarray(0))
result, updates = theano.scan(fn=lambda seq, prior: seq + 2,
                              sequences=T.arange(a),
                              outputs_info=outputs,
                              non_sequences=None)
sf5 = theano.function(inputs=[a], outputs=result, updates=updates)
sf5(5)
# array([2, 3, 4, 5, 6])

定数ではなく直前の値 (prior) を足すと以下のようになる。直前の結果に各要素が加算されるので、結果は [0, 0+1, 1+2, 3+3, 6+4] のベクトル。

outputs = T.as_tensor_variable(np.asarray(0))
result, updates = theano.scan(fn=lambda seq, prior: seq + prior,
                              sequences=T.arange(a),
                              outputs_info=outputs,
                              non_sequences=None)
sf6 = theano.function(inputs=[a], outputs=result, updates=updates)
sf6(5)
# array([ 0,  1,  3,  6, 10])

n_steps を指定した場合は シーケンスがその長さに達した時点で処理が終わる。

outputs = T.as_tensor_variable(np.asarray(0))
result, updates = theano.scan(fn=lambda seq, prior: seq + prior,
                              sequences=T.arange(a),
                              outputs_info=outputs,
                              non_sequences=None,
                              n_steps=3)
sf7 = theano.function(inputs=[a], outputs=result, updates=updates)
sf7(5)
# array([0, 1, 3])

複数のシーケンスに対して処理をする場合は、sequences に対して各シーケンスに対応するシンボルのリストを渡す。

a = T.ivector('a')
b = T.ivector('b')
result, updates = theano.scan(fn=lambda seq1, seq2: seq1 + seq2,
                              sequences=[a, b],
                              outputs_info=None,
                              non_sequences=None)
sf8 = theano.function(inputs=[a, b], outputs=result, updates=updates)
sf8([1, 3, 6], [2, 7, 8])
# array([ 3, 10, 14], dtype=int32)

複数の返り値を持たせる場合は theano.function と同様。

result, updates = theano.scan(fn=lambda seq1, seq2: [seq1 + seq2, seq1 - seq2],
                              sequences=[a, b],
                              outputs_info=None,
                              non_sequences=None)
sf9 = theano.function(inputs=[a, b], outputs=result, updates=updates)
sf9([1, 3, 6], [2, 7, 8])
# [array([ 3, 10, 14], dtype=int32), array([-1, -4, -2], dtype=int32)]

まとめ

theano.function, theano.scan の挙動を整理した。

scan 処理を書く場合は、

  1. 処理に対して適切な 引数 sequences, outputs_info, non_sequences が決まる
  2. fn に渡される引数が決まる -> fn の具体的な処理が決まる

scan の処理を読み解く場合は、

  1. まず引数 sequences, outputs_info, non_sequences を確認し、Loop / Iteration どちらなのかを見分ける
  2. fn に何が渡っているかがわかる -> fn の処理を読み解く

5/3追記 @xiangze さんが、scan の条件付き終了 (while) などについてエントリを書かれているので、こちらもご参照ください。

xiangze.hatenablog.com