HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学、量子計算などをテーマに扱っていきます

Pytorchで遊ぼう【データ成形からFNNまで】

 

 

follow us in feedly

f:id:s0sem0y:20170925014700j:plain

 

はじめに

最近Chainerの素晴らしさを再認識し、TensorFlow+KerasからChainerに舞い戻ってきました。一方で、世界的にはPytorchの活躍が見られることも認識していました。

 

ChainerをフォークしたPytorchが一体どんなものなのか気になり、ここ数日触れている内にPytorchも素晴らしい(Chainerが素晴らしいんだから当たり前だけど)ことが分かってきました。

 

Chainerに触れたことがある人ならば、それほど違和感なくPytorchを操ることができるでしょう。今はChainerの方が便利ですが、開発参加者の数を考える上では今後どうなるかわかりません。ここで少しだけPytorchで遊んでみることにしましょう。

 

Pytorch

PytorchはTorchというテンソルを扱うライブラリを基本にしています。Lua言語で使われてきたTorch7には根強いファンもいましたが、これがPythonで使える形となり、Pytorchが発表されて以降、徐々に普及してきています。

 

f:id:s0sem0y:20170925025953p:plain

(詳しくは以下参照) 

s0sem0y.hatenablog.com

 

 

TorchはほぼNumpyと同様の操作が可能です。

 

github.com

 

多少メソッドの命名などに違いはあれど、だいたいのことはできます(でもNumpyの方が高機能かな)。

 

PytorchではTorchを基本に携えて、ニューラルネットワークに必要な様々な計算処理、計算過程の保存(計算グラフの構築)を実装してあります。

 

これらの実装はChainer同様、Pythonにより行われているため、ニューラルネットの構築や計算過程を独自に作りたい場合にも、Pythonの機能をフル活用できます。(もちろんDefine by Runなので入力データに応じて計算グラフを動的に変更することも可能です。)

 

 

 

簡単な例題で遊ぶ

今回はPytorchを使ってフィードフォワードニューラルネットワークを実装し、分類問題を解いてみます。Torchを使うということで、普段numpyに慣れている人にとっては少し不慣れな操作も必要になります。

 

そこで今回はnumpyでデータを作るところから開始していきます。

 

データを作る

まずは以下のようにNumpyでこれから分類するデータを作ります。

 

f:id:s0sem0y:20170927003216p:plain

 

図を見てもらえば分かる通り分類ができそうです(そうなるように生成した)。これは要するに3種類のデータが別々の正規分布から出てきているということです。

 

実問題でも、本来分類が可能であるようなデータというのは、きっと背後に潜む確率分布には差異があるはずです。高次元のデータでは可視化ができませんが、ともかく手持ちのデータは背後には何らかの違いがあると仮定して、機械学習を行うことになります。

 

Torchへ変換する

まずはデータをひとつの配列に収めてしまいましょう。その後、まとめてTorchに変換します。

numpyからtorchへの変換は、関数として準備されているため非常に簡単です。

 

f:id:s0sem0y:20170927003919p:plain

 

今、データ数が300で、特徴量が2のtorch型のデータを獲得することが出来ました。同じようにラベルも変換しておきます。ここで注意点があります。

 

Pytorchでは入力データがfloat型、target(ラベル)はlong型という決まり

他クラス分類のラベルは、一次元配列でクラスの番号を格納する(one_hotでない)

 

の2点は確実に抑えておきましょう。

1つ目の方に関しては、私もハマってしまったところです。エラーメッセージをコピペで検索すれば英語のサイトで同じような問題にハマっている人がいたため、すぐに解決策は見つかりました。

2つ目に関してはChainerユーザーにとっては普通のことでしょう。KerasやTensorFlowユーザーは、多くの場合one_hotで与えることが多いのではないかと思います。ココラヘンはフレームワークごとに実装が異なるため、注意が必要です。

 

ミニバッチトレーニングのためのデータ準備

あとはニューラルネットワークを作って、上記のデータをぶち込んでやれば学習は進みます。今回のデータはそれでも良いのですが、もっと一般的な用途のためにミニバッチトレーニングの方法を解説します。

 

ミニバッチトレーニングでは、手元の全てのデータからランダムにデータを複数取り出し、パラメータの更新に使います(故に確率的勾配法と呼ばれる)。ココでは、全ての手持ちのデータからランダムにトレーニングデータとターゲットラベルをセットで取り出す準備をします。

 

f:id:s0sem0y:20170927010058p:plain

 

TensorDatasetは、train_torchとtarget_torchを入れてやると、1つ1つのデータとそれに対応するターゲットラベルを対にして格納してくれます。train_set[i]にはi番目のトレーニングデータとラベルがタプルに格納されるという形です。

 

このようにしておいたtrain_setを次はDataLoaderに渡してやります。batch_sizeを指定することで、全てのデータから一度に何個取り出すかを決めることができます。またトレーニング時にはshuffle=Trueにしておくことで、ランダムに取り出すことを可能にしてくれます(Chainer

にも全く同様の機能はありますね。Pythonのイテレーターです。)。

 

 

このtrain_loaderを使うことで、具体的には以下のように使えます。

f:id:s0sem0y:20170927010523p:plain

 

場所の都合上途中で切ってしまっていますが、batch_size個のtrainとtargetがtorch.Tensorが取り出されているのが分かります。上記のコードでは、これが、全てのデータを総なめするまで続くわけです(だから学習時のコードに使える)。

(学習時のコード例)

f:id:s0sem0y:20170927011235p:plain

 

ニューラルネットを作る

隠れ層4つで隠れ層のユニットを全て100にしたフィードフォワードニューラルネットワークを書きます。基本的にはChainerとほとんど変わりません。

 

f:id:s0sem0y:20170927011939p:plain

 

ただし、線形変換が何次元から何次元なのか(つまり行列のサイズがいくつなのか)を明示しなければなりません。Chainerならば各層の入力がいくつなのかを明示する必要はなく、「None」で指定しておいても構いません。

それがPytorchではできないので、しっかりと値を明示しておきましょう。

 

また、回帰に使うのか分類に使うのか、いずれにしても出力に活性化関数を掛けないでください。

 

回帰は出力の活性化がが恒等関数なのが当たり前ですが、分類では通常softmaxが使われます。しかし、これは計算機の都合上、ネットワークがロスの正しい把握をすることを邪魔するという報告もあるため、使われない傾向になっています(これはPytorchに限ったことじゃあないんですが)。

 

ただ紛いなりにも、softmaxを出力層に持ってくることは理屈的に意味があるので(最尤推定)、これを崩さないように工夫はされています。通常は代わりにloss関数の方をいじってしまっているので、ネットワークの方で変なことをしないようにしてください(フレームワークが提供するloss関数はsoftmaxを掛けないことを前提に設計されている)。

 

 

最適化関数をセットして学習

最適化関数は既にいろいろ準備されています。今回は交差エントロピーを用います。

学習率は通常よりかなり高めですが、とりあえず回すだけなので良しでしょう。

f:id:s0sem0y:20170927014020p:plain

 

学習率が高いため、やはり後半lossが上がることが出てきてしまっています。損失関数の形状によっては発散するということも起こってくるため注意が必要です。

Pytorchのポイントとしては、Chainerではlossをネットワークの方で書いてしまって、モデルの出力として直接lossを得ることが出来ました(また、L.Classifierを使うことも可能)。

 

Pytorchでは、ロスを計算するクラスにネットワークの出力をいれてやる必要があります。

 

また、lossはPytorchのVariableとして帰ってくるので、これをloss.data[0]で数値として見る必要があります。loss.dataにはtorchのfloat型が入っており、これはformatで取り出せないので注意してください(ChainerならVariable.dataがnumpyなのでそれでOKだが)。

 

 

 

最後に

とりあえずつかれたのでココまで。ここまでできていれば、for文の中に検証データを流すことも簡単でしょう。Accuracyの方はChainerのように計算する関数が準備されていないため、自分で作ってやる必要があります。まだまだv0.2で発展途上ですが今後に期待です。

 

 

余力がアレばAccuracy出したり評価するところまで書こうかな。

 

 

まあ、だいたい以下のチュートリアル見れば分かるんですが。

github.com