"機械学習","信号解析","ディープラーニング"の勉強

HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学などをテーマに扱っていきます

tensorflowの計算グラフにif文を入れる

 

 

f:id:s0sem0y:20170707094814j:plain

 

はじめに

TensorFlowを使いこなすためには、TensorFlowの勉強をする必要があります。

TensorFlowではPythonで計算グラフを構築しておき、あとの演算処理はTensorFlow(C++)に任せることになります。演算処理の部分にPythonの制御構造(for文やif文)を入れることはできないため、制御構造を持つ計算グラフそれ自体を作ってやる必要が出てきます。

 

今回はTensorFlowの計算グラフに制御構造を入れる2つの関数を紹介します。

 

if文

TensorFlowの計算グラフにif文を入れるためには「tf.cond」を使います。tf.condの使い方を見る前に、これと同じ動作をするPythonのプログラム例を見ましょう。

 

一方の値を返す関数

def cond(bool, x, y):
if bool:
return x
else:
return y

 

上記のコードがどのような処理を行うのか分かるでしょうか?cond関数はboolと2つの引数を受け取り、受け取ったboolにしたがってx,yの一方の引数のみを返します。

 

 

print(cond(True, 5, 10))

 

このようなcond関数に対して、Trueをセットし、引数として5と10を与えます。この出力はすぐに5だと分かるでしょう。もしもboolをFalseにセットすれば10を返してくることも分かります。

 

こんなプログラム、わざわざ関数で定義してやる必要あるのか?と思うかもしれません。Pythonならば必要ありません。

 

しかしTensorFlowを扱う場合には、このように全て関数の処理をしていると考えると非常に見通しが良くなります(宣言的言語と関数型言語の関係性は深いです)。

 

一方の関数を返す関数

次に以下のような少し複雑にしたコードを見ましょう

 

 

 

def add(a, b):
return a + b

def multiply(a, b):
return a * b

def cond(bool, func1, func2):
if bool:
return func1
else:
return func2

 

さあ少し複雑になったようですが、それでも簡単です。cond関数はboolを引数に、func1かfunc2のどちらを実行するかを選んでいます。これに対して、

 

print(cond(True, add(1,2), multiply(1,2)))

 

を実行した場合、add(1, 2)が実行されるのが分かるでしょうか。すなわち出力は3になります。boolの部分をFalseにすればmultiply(1, 2)が実行され出力は2になるでしょう。

 

上記のPythonのプログラム例が分かったのであれば、TensorFlowで計算グラフ(すなわち関数)を構築してやることは簡単です。実を言うと、まさに上記で構築したcond関数に相当する計算グラフが、tf.condになります。

 

tf.condは第1引数にtf.boolを、第2,3引数に関数を受け取ります。そして、tf.boolの値にしたがっていずれかの関数を実行します(すなわち、上記で見たPythonのcond関数と全く同じ動作を、「計算グラフ内で」実現します)。

 

 

x = tf.cond(True, lambda: 1+2, lambda: 1*2)

print(tf.Session().run(x))

 

これはもちろん計算グラフとして構築されているので、Sessionを動かしてやらねばなりません。そうすると期待通り、3という出力が得られるはずです。

 

注意としては、以下のコードは動作しないということです。

 

x = tf.cond(True, 5, 10)

print(tf.Session().run(x))

 

tf.condが第1引数を見て、第2か第3引数のいずれかを返してくるのだと考えた場合は、このコードなら5を返してくれそうです。しかし実際にはtf.condの第2,第3引数は関数でなければなりません。したがって定数を返す関数を作ってやる必要があり、望みのことを実現するためには

 

 

x = tf.cond(True, lambda:5, lambda:10)

print(tf.Session().run(x))

 

とやる必要があります。これならちゃんと動いてくれます。

(一応説明の都合上boolという名前を変数に付けましたが、もっとちゃんとした名前つけましょう。)

 

 

実用上の形式

 実際には引数であるboolとa,bは後から与えたい場合があるので、プレースホルダーを使って以下のように実装しておくと良いでしょう。

 

boolian = tf.placeholder(dtype=tf.bool)
a = tf.placeholder(dtype=tf.float32)
b = tf.placeholder(dtype=tf.float32)

x = tf.cond(boolian, lambda: a+b, lambda: a*b) print(tf.Session().run(x, feed_dict={boolian: True, a: 1, b: 2}))

 

プレースホルダーにはセッションを走らせるときに後から値をフィードしてやることになります。

また、以下のようにプレースホルダーに与えた2つの数値を単に足すか、線形結合するかを分岐させることもできます。このoutputsの部分がとあるニューラルネットの出力でも良いです。

 

boolian = tf.placeholder(dtype=tf.bool)
a = tf.placeholder(dtype=tf.float32)
b = tf.placeholder(dtype=tf.float32)
w1 = tf.constant(5.)
w2 = tf.constant(10.)
outputs = a*w1 + b*w2

x = tf.cond(boolian, lambda: a+b, lambda: outputs) print(tf.Session().run(x, feed_dict={boolian: True, a: 1, b: 2}))

 

バッチデータごとにif文の判定を変える

もしかしたら、バッチデータのi番目とj番目では処理を変えたいということがあるかもしれません。Tensorflowでは通常1つずつデータを流すわけではなく、多くのデータを同時に流して並列処理を行います。そのバッチデータのインデックスに応じて処理を変えたい場合はtf.whereが使えます。

 

a = tf.constant([5, 4])
b = tf.constant([2, 2])

p = tf.constant([True, False])

x = tf.where(p, a + b, a * b)

print(tf.Session().run(x))

 

例えば上記の処理では、

[7,8]

が出力されます。これはa,b,pのインデックスに対応付けて、処理がそれぞれTrueとFalseで異なっている結果です。

 

こっちのtf.whereは関数を引数として受け取るわけではなく、直接何らかの値を受け取って動作します。たとえデータが単なるスカラーだったとしても、tf.whereは動作できるので、今回挙げた簡単な例の全てはtf.whereで操作可能であり、

 

boolian = tf.placeholder(dtype=tf.bool)
a = tf.placeholder(dtype=tf.float32)
b = tf.placeholder(dtype=tf.float32)

x = tf.where(boolian, a+b, a*b) print(tf.Session().run(x, feed_dict={boolian: True, a: 1, b: 2}))

 

などとしてもしっかり動きます。(tf.condは関数を、tf.whereは直接値を引数に取るので、condではlambda:でパックしていたが、その必要はなくなる)

 

 

 

 

 

 

 

 

 

s0sem0y.hatenablog.com