Python Theano function / scan の挙動まとめ
勉強のため たまに Pylearn2
など Theano
を使ったパッケージのソースを眺めたりするのだが、theano.scan
の挙動を毎回 忘れてしまう。繰り返し調べるのも無駄なので、一回 整理したい。theano.scan
の動作は theano.function
が前提となるため、あわせて書く。
準備
import numpy as np import theano import theano.tensor as T
theano.function
まずは Theano
における関数にあたる Function
インスタンスを作成する theano.function
の基本的な挙動について。引数はいろいろあるが、特に重要と思われるのは以下の4つ。
inputs
:Function
への入力 (引数) に対応するシンボル。outputs
:Function
化される式。updates
: 共有変数を更新する式。givens
: 引数 (inputs
) -> 何かへのマッピングを行う辞書。
引数 ひとつを受け取って 1 インクリメントした値を返す Function
は以下のように作る。f1(3)
の 3 が 定義中のシンボル a
に対応する。inputs
には、引数がひとつ もしくは ない場合でもリストを指定すること。
a = T.lscalar('a') f1 = theano.function(inputs=[a], outputs=a + 1) f1 # <theano.compile.function_module.Function at 0x10d00fb90> f1(3) # array(4)
補足 theano
は内部的に 式をグラフ構造として処理している。その構造は debugprint
などで表示できる。
theano.printing.debugprint(f1) # Elemwise{add,no_inplace} [@A] '' 0 # |TensorConstant{1} [@B] # |a [@C]
補足 また、グラフ構造はダイアグラムとしても描画できる。IPython (Jupyter) 上で描画する場合は、
import IPython.display as display def plot(obj): svg = theano.printing.pydotprint(obj, return_image=True, format='svg') return display.SVG(svg) plot(f1)
引数をふたつ受け取りそれらの和を返す Function
は以下のようになる。
b = T.lscalar('b') f2 = theano.function(inputs=[a, b], outputs=(a + b)) f2(3, 4) # array(7)
また、各引数の既定値 や 別名は theano.Param
を通すことで指定できる。以下の例では、ふたつめの引数 b
が指定されない場合、既定値として 1 を使う。
f3 = theano.function(inputs=[a, theano.Param(b, default=1)], outputs=(a + b)) f3(3) # array(4) f3(3, 2) # array(5)
引数をふたつ受け取り、それらの和と差を返す Function
。outputs
に複数の式をリストで渡せばよい。
f4 = theano.function(inputs=[a, b], outputs=[a + b, a - b]) f4(3, 4) # [array(7), array(-1)]
引数をひとつ受け取り、別に定義した 共有変数との和を返す Function
。共有変数は function
内で直接使ってよい。
s = theano.shared(2) s.get_value() # array(2) f5 = theano.function(inputs=[a], outputs=a + s) f5(3) # array(5)
引数 / 返り値はなしで、別に定義した 共有変数を破壊的に 1 増やす Function
。共有変数に対する更新は updates
で定義。
inc = theano.function(inputs=[], outputs=None, updates={s: s + 1}) s.get_value() # array(2) inc() # [] s.get_value() # array(3)
引数をひとつ受け取り、共有変数との和を返す Function
。同時に、共有変数を破壊的に 1 増やす。共有変数に対する更新は updates
で定義。updates
の処理は outputs
の後に行われているようだ。
s.set_value(0) f6 = theano.function(inputs=[a], outputs=a + s, updates={s: s + 1}) f6(2) # array(2) s.get_value() # array(1) f6(2) # array(3)
また、引数はマッピングさせて使うこともできる。以下のように givens
を定義すると、式中のシンボル s
が 引数 x
でマッピングされる。
x = T.lscalar('x') f7 = theano.function(inputs=[a, x], outputs=a + s, givens=[(s, x)]) f7(2, 5) # array(7)
マッピングは定数でもよい。
f8 = theano.function(inputs=[a], outputs=a + s, givens=[(s, 3)]) f8(2) # array(5)
引数が vector
ならブロードキャストされる。
v = T.ivector() f9 = theano.function(inputs=[v], outputs=v * 2) f9([1, 1, 1]) # array([2, 2, 2], dtype=int32)
theano.scan
次に、Theano
における 繰り返し処理に対応する theano.scan
。scan
にはおおきく 以下 2 種類の動きがあり、混同するとわけわからなくなる。それぞれ明確な名前が付いているわけではなさそうだが、便宜上 区別したいので 以下ドキュメントの章題をもとに それぞれ Loop / Iteration と書く。
参考 scan – Looping in Theano — Theano 0.7 documentation
Loop: ある関数
fn
を、引数に対してn_steps
回 適用する。返り値は 長さn_steps
のベクトルとなり、[fn(x), fn(fn(x))...]
のような処理になる。Iteration: ある関数
fn
を、シーケンス的な引数に対して適用する。返り値は 引数の各要素と同じ長さのベクトルとなり、[fn(x) for x in ...]
のような処理になる。
scan
がどちらの処理を行うかは 引数をどのように渡すかによって決まる。
fn
: Loop / Iteration において適用される式。sequences
: シーケンスとして Iteration 処理される引数。繰り返しのたびにシーケンスの 1, 2, 3... 番目の要素が順にfn
へ渡る。outputs_info
: Loop 処理の初期値となる値。non_sequences
: シーケンスでないfn
への引数。繰り返しのたびに同じ値がfn
へ渡る。n_steps
: 繰り返し処理を行う回数。
このとき、scan
が fn
に渡す引数は (最大で) 以下の 3 つになる。それぞれ、対応する引数がない場合は省略される (fn
に渡される引数の数自体が変わる)。
- シーケンスの要素 (
sequences
が指定されている場合) - 直前の繰り返し処理の結果 (
outputs_info
が指定されている場合) - シーケンスでない引数 =
non_sequences
そのもの (non_sequences
が指定されている場合)
そのため、scan
処理を書く場合の考え方は以下のようになると思う。
- 処理に対して適切な 引数
sequences
,outputs_info
,non_sequences
が決まる fn
に渡される引数が決まる ->fn
の具体的な処理が決まる
まずは 単純な Loop 処理。ひとつ目の引数 5 を 2 倍する処理を 3 回繰り返す。結果は [5*2, 5*2*2, 5*2*2*2]
のベクトルとなる。
n = T.iscalar('n') result, updates = theano.scan(fn=lambda prior, nonseq: prior * 2, sequences=None, outputs_info=a, non_sequences=a, n_steps=n) sf1 = theano.function(inputs=[a, n], outputs=result, updates=updates) sf1(5, 3) # array([10, 20, 40])
Loop 処理の最後の結果だけが欲しい場合は、theano.function
の outputs
でベクトル末尾の要素を指定する。
sf2 = theano.function(inputs=[a, n], outputs=result[-1], updates=updates) sf2(5, 3) # array(40, dtype=int32)
ベクトルに対する Loop 処理。シンボルの型と引数が変わるだけ。
v = T.ivector('v') result, updates = theano.scan(fn=lambda prior, nonseq: prior * 2, sequences=None, outputs_info=v, non_sequences=v, n_steps=n) sf3 = theano.function(inputs=[v, n], outputs=result[-1], updates=updates) sf3([1, 2, 3], 3) # array([ 8, 16, 24], dtype=int32)
fn
内で non_sequences
のみを利用したときの結果をみる。non_sequences
には 繰り返し回数にかかわらず同じ値が渡ってくるため、ループ回数 n_steps
は結果に影響しなくなる。
result, updates = theano.scan(fn=lambda prior, nonseq: nonseq * 2, sequences=None, outputs_info=v, non_sequences=v, n_steps=n) sf4 = theano.function(inputs=[v, n], outputs=result[-1], updates=updates) sf4([1, 2, 3], 100) # array([2, 4, 6], dtype=int32)
続けて Iteration 処理。以下は 渡された引数 5 と同じ長さのベクトル [0, 1, 2, 3, 4]
をつくり、各要素に対して 2 を加算する。
outputs = T.as_tensor_variable(np.asarray(0)) result, updates = theano.scan(fn=lambda seq, prior: seq + 2, sequences=T.arange(a), outputs_info=outputs, non_sequences=None) sf5 = theano.function(inputs=[a], outputs=result, updates=updates) sf5(5) # array([2, 3, 4, 5, 6])
定数ではなく直前の値 (prior
) を足すと以下のようになる。直前の結果に各要素が加算されるので、結果は [0, 0+1, 1+2, 3+3, 6+4]
のベクトル。
outputs = T.as_tensor_variable(np.asarray(0)) result, updates = theano.scan(fn=lambda seq, prior: seq + prior, sequences=T.arange(a), outputs_info=outputs, non_sequences=None) sf6 = theano.function(inputs=[a], outputs=result, updates=updates) sf6(5) # array([ 0, 1, 3, 6, 10])
n_steps
を指定した場合は シーケンスがその長さに達した時点で処理が終わる。
outputs = T.as_tensor_variable(np.asarray(0)) result, updates = theano.scan(fn=lambda seq, prior: seq + prior, sequences=T.arange(a), outputs_info=outputs, non_sequences=None, n_steps=3) sf7 = theano.function(inputs=[a], outputs=result, updates=updates) sf7(5) # array([0, 1, 3])
複数のシーケンスに対して処理をする場合は、sequences
に対して各シーケンスに対応するシンボルのリストを渡す。
a = T.ivector('a') b = T.ivector('b') result, updates = theano.scan(fn=lambda seq1, seq2: seq1 + seq2, sequences=[a, b], outputs_info=None, non_sequences=None) sf8 = theano.function(inputs=[a, b], outputs=result, updates=updates) sf8([1, 3, 6], [2, 7, 8]) # array([ 3, 10, 14], dtype=int32)
複数の返り値を持たせる場合は theano.function
と同様。
result, updates = theano.scan(fn=lambda seq1, seq2: [seq1 + seq2, seq1 - seq2], sequences=[a, b], outputs_info=None, non_sequences=None) sf9 = theano.function(inputs=[a, b], outputs=result, updates=updates) sf9([1, 3, 6], [2, 7, 8]) # [array([ 3, 10, 14], dtype=int32), array([-1, -4, -2], dtype=int32)]
まとめ
theano.function
, theano.scan
の挙動を整理した。
scan
処理を書く場合は、
- 処理に対して適切な 引数
sequences
,outputs_info
,non_sequences
が決まる fn
に渡される引数が決まる ->fn
の具体的な処理が決まる
scan
の処理を読み解く場合は、
- まず引数
sequences
,outputs_info
,non_sequences
を確認し、Loop / Iteration どちらなのかを見分ける fn
に何が渡っているかがわかる ->fn
の処理を読み解く
5/3追記 @xiangze さんが、scan
の条件付き終了 (while
) などについてエントリを書かれているので、こちらもご参照ください。