"機械学習","信号解析","ディープラーニング"の勉強

読者です 読者をやめる 読者になる 読者になる

HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学などをテーマに扱っていきます

誤差逆伝搬法(バックプロパゲーション)とは

 はじめに

ニューラルネットのフレームワークを使うと、誤差逆伝搬は既に実装されているため、ほとんど意識すること無く使えてしまいます。言わばブラックボックスという状態です。ここで誤差逆伝搬法について学んでおくことで、ニューラルネットがいかにして学習を行っているのかを理解しましょう。

 

数式が相応に難解なため、一番最後まで飛んでしまっても構いません。

最終的に得られる式を見れば、誤差逆伝搬法の由来を理解できるかと思います。

 

ニューラルネットの基礎

ニューラルネットの順伝搬

1つの層での処理

まずニューラルネットワークといえば入力

 

{\bf x}=(x_1,...,x_d)^T

 

に対して以下のような処理を行います。

 

\bf z = Wx + b

 

a = f(z)

 

基本的にはたったコレだけです。

\bf WD×dの行列で、\bf bD×1の行列(ベクトル)です。

結果\bf zD×1次元の行列(ベクトル)になります。

fは何らかの活性化関数で、成分毎に作用し、aD×1次元の行列(ベクトル)です。

 

一括すれば

 

{\bf a} = f({\bf Wx + b})

 

という処理になります。ここで、

 

{\bf x'}=(x_0,x_1,...,x_d)^T

 

{\bf W'}=\bf (b,W)

 

と置き直せば

 

{\bf a} = f({\bf W'x'})

 

という処理に一括することができます。

x_0は単にx_0=1でありこうしておくことで、バイアス\bf bの計算を含んでしまうことができるというちょっとしたテクニックです。

以後、\bf W\bf xはこのような処理が施されていることとします。

 

多層での処理

前述までの

 

{\bf a} = f({\bf Wx})

 

に関して、これを1層目の処理と明記するために以下のよう表現にします。

 

{\bf x^{(2)}}=f(\bf W^{(1)}x^{(1)})

 

こうしておくと、x^{(2)}2層目への入力であるとハッキリ分かります。

(単に見やすさの問題で、文字は実際は何でもいいのです)

 

するとl層とl+1層間での処理については

 

{\bf x^{(l+1)}}=f(\bf W^{(l)}x^{(l)})

 

と表すことができ、層がL個あるならば、結局上記の式をl=1からl=Lまで繰り返すという処理になります。これが順伝搬と呼ばれる処理です。

 

成分を明示的に書けば

 

x_j^{(l+1)}=f(\sum_i w_{ji}^{(l)}x_i^{(l)})

 

となります。(この処理がわからない場合は下記の記事)

s0sem0y.hatenablog.com

順伝搬のまとめ

l+1層とl層での処理は以下で表され

 

x_j^{(l+1)}=f(\sum_i w_{ji}^{(l)}x_i^{(l)})

 

これをl=1からl=Lまで繰り返す。

線形変換→活性化を小分けにすると

 

 

z_j^{(l+1)}=\sum_i w_{ji}^{(l)}x_i^{(l)}

 

x_j^{(l+1)}=f(z_j^{(l+1)})

 

と表すことができます。

 

ニューラルネットワークの学習

勾配法

学習とは損失関数を減少させるように\bf Wを変更していく処理です。

損失関数E({\bf W})を減少させるためには、これの勾配を計算し、勾配は関数を増加させる方向を示しているため

 

{\bf W} ← {\bf W} - ε\frac{\partial E({\bf W})}{\partial {\bf W}}

 

と更新をします。

εは学習率と呼ばれ、どれだけ大きく更新を行うかを決めるパラメータです。

十分損失関数が小さくなるまで更新を繰り返すことで学習を行います。この方法を勾配法と言い、ニューラルネットの基礎的な手法となっています。

 

例えば回帰ではニューラルネットの出力\bf yとし、目標値(ラベル)を\bf tとして、損失関数は

 

E({\bf W})=\frac{1}{2}({\bf y-t})^2

 

のように設定されます。データがN個あれば、それぞれのデータの損失E_n

 

E_n({\bf W})=\frac{1}{2}({\bf y_n-t_n})^2

 

を考えて、全てのデータの損失の和

 

E({\bf W})=\sum_{n=1}^N E_n = \frac{1}{2}\sum_{n=1}^N ({\bf y_n-t_n})^2

 

を損失関数の具体的な計算とします。

つまり出力\bf yがニューラルネットのパラメータ\bf Wに依存する値となっており、Wを色々調整してEを減少させる(つまり\bf y\bf tに近づける)というのがニューラルネットの学習です。

 

s0sem0y.hatenablog.com

 

微分の難しさ

問題は、勾配を如何にして計算するかです。

行列\bf WでスカラーEの微分というのは、\bf Wと同じサイズの行列になり、(j,i)成分が\frac{\partial E}{\partial w_{ji}}となります(単に各成分で微分していくだけ)。ベクトルも行列の一種(例えばd×1の行列)のように考えて、同じように成分毎に微分するだけです。

 

結局知りたいのは、w_{ji}での微分であって、これが計算できれば\bf Wの更新も上手く行えるという状況です。しかし、歴史上これの計算が非常に手間取ったため、ニューラルネットは一時期姿を消しました。

理由は

 

{\bf x^{(l+1)}}=f(\bf W^{(l)}x^{(l)})

 

という処理を、例えば2回行えば、ニューラルネットの出力\bf y

 

{\bf y}={\bf x^{(3)}}=f(\bf W^{(2)}f(\bf W^{(1)}x^{(1)}))

 

という計算になります。成分毎に書けば

 

{y_j}={x_j^{(3)}}=f(\sum_i w_{ji}^{(2)}f(\sum_i w_{ji}^{(1)}x_i^{(1)}))

 

となります。パラメータを更新するためにはw_{ji}^{(1)}w_{ji}^{(2)}すべての微分を計算しなければなりません。こいつはかなり厄介です。

これを解決したのが誤差逆伝搬法です。

 

誤差逆伝搬法

 

問題

 

ニューラルネットワークの出力は以下で表されます。

 

{y_j}={x_j^{(3)}}=f(\sum_i w_{ji}^{(2)}f(\sum_i w_{ji}^{(1)}x_i^{(1)}))

 

通常データは複数あるわけでn個目のデータであることを明示すれば

 

{y_{nj}}={x_{nj}^{(3)}}=f(\sum_i w_{ji}^{(2)}f(\sum_i w_{ji}^{(1)}x_{ni}^{(1)}))

 

であり、更に損失はベクトルの表現で以下ですから

 

E({\bf W})=\sum_{n=1}^N E_n = \frac{1}{2}\sum_{n=1}^N ({\bf y_n-t_n})^2

 

これも成分を明示すると

 

E({w_{ji}})=\sum_{n=1}^N  E_n = \frac{1}{2}\sum_{n=1}^N \sum_{j}(y_{nj}-t_{nj})^2

 

となります。これでデータの個数と成分も明示した表現に書き換えることができました。

最後のニューラルネットの出力を誤差関数に代入すれば

 

E({w_{ji}})=\frac{1}{2}\sum_{n=1}^N \sum_{j}(f(\sum_i w_{ji}^{(2)}f(\sum_i w_{ji}^{(1)}x_{ni}^{(1)}))-t_{nj})^2

 

となります。これを存在する全てのw_{ji}で1つずつ微分をしていこうというわけですが、もうやりたくないですよね。

w_{ji}に関して非常に複雑な合成関数の形をしており、普通に1つ1つ微分していったんではとんでもない計算になってしまいそうなのが目に見えています。しかも層の数を2つとかなり制限していてコレです。もっと多層にすれば、入れ子がどんどん増えていき手におえません。

 

 

誤差逆伝搬法

まずは損失の和を計算する

まず以下の式に関して、データが1つであるとしましょう。

複数あっても結局後で和を取ればいいだけなので問題ありません。データ点1つに対して計算ができるようになればそれでいいのです。

ちなみにデータ1つを適当に選んで学習するのが確率的勾配法、データを複数まとめるのがミニバッチ法、データをすべて使うのがバッチ法であり、単にいくつのデータに対して損失の和を取るかで区分けされているにすぎません。従って、データ1つに対する誤差逆伝搬が分かればそれで十分です。

 

ということで、以下の損失関数からnは消し去ってしまいましょう。

 

E({w_{ji}})=\frac{1}{2}\sum_{n=1}^N \sum_{j}(f(\sum_i w_{ji}^{(2)}f(\sum_i w_{ji}^{(1)}x_{ni}^{(1)}))-t_{nj})^2

 

特に本質的な話ではありませんが、とりあえず見やすくはなったと思います。

 

 

E({w_{ji}})=\frac{1}{2}\sum_{j}(f(\sum_i w_{ji}^{(2)}f(\sum_i w_{ji}^{(1)}x_i^{(1)}))-t_{j})^2

 

これの計算を複数のデータで行い、和を取ればいいということになります。

今回は回帰を想定していますが、分類ならば交差エントロピーを使うだけで話は一緒です。

 

微分の連鎖律

一般的な議論をするために損失関数は単にE({w_{ji}})と記述することにします。

 目標は以下の計算を効率化することです。

 

\frac{\partial E({w_{ji}})}{\partial {w_{ji}}}

 

今は明示的には見えていませんが、Eはとてつもなく複雑なので簡単には計算できません。しかし、微分には連鎖率というものがあり以下のように(あたかも分数のごとく)計算することが許されています。

 

\frac{\partial E({w_{ji}})}{\partial {w_{ji}}}=\frac{\partial E({w_{ji}})}{\partial {a_{j}}} \frac{\partial a_{j}}{\partial {w_{ji}}}

 

今は適当にa_{j}を間に挟みましたが、もっと一杯挟んでも良いです。

ニューラルネットでは順伝搬の最後L層で

 

z_j^{(L+1)}=\sum_i w_{ji}^{(L)}x_i^{(L)}

 

という計算が行われていることに着目します。

L層目への入力が\bf x^{(L)}であるように添字を付けたのでL+1という添字が出てきてきますが、気持ち悪かったら、添字の付け方を0から開始したりL層目の出力を基準につけるなどすればいい)

 

\frac{\partial E({w_{ji}})}{\partial {w_{ji}^{(L)}}}=\frac{\partial E({w_{ji}})}{\partial {z_{j}^{(L+1)}}} \frac{\partial z_{j}^{(L+1)}}{\partial {w_{ji}^{(L)}}}

 

という連鎖律を使います。ここで\frac{\partial z_{j}^{(L)}}{\partial {w_{ji}^{(L1)}}}の計算は簡単であるということに注目してください。単に

 

\frac{\partial z_{j}^{(L+1)}}{\partial {w_{ji}^{(L)}}}=x_i^{(L)}

 

になります。従って、w_{ji}^{(L)}に関する微分は

 

\frac{\partial E({w_{ji}})}{\partial {w_{ji}^{(L)}}}=\frac{\partial E({w_{ji}})}{\partial {z_{j}^{(L+1)}}} \frac{\partial z_{j}^{(L+1)}}{\partial {w_{ji}^{(L)}}}=\frac{\partial E({w_{ji}})}{\partial {z_{j}^{(L+1)}}} x_i^{(L)}

 

と求まりました。そうなると、知りたいのは\frac{\partial E({w_{ji}})}{\partial {z_{j}^{(L+1)}}}という部分の計算になります。これで問題が楽になったのか、今は定かではありませんが、後に見るように驚異的なアルゴリズムの導出に繋がります。

 

出力での損失の微分

知りたい計算が\frac{\partial E({w_{ji}})}{\partial {z_{j}^{(L+1)}}}にすり替わったので、これを一々書くのがめんどうということで、以下のように文字でおいてしまいます。

 

δ_j^{(L)}=\frac{\partial E({w_{ji}})}{\partial {z_{j}^{(L+1)}}}

 

すると損失関数の微分は

 

\frac{\partial E({w_{ji}})}{\partial {w_{ji}^{(L)}}}=δ_j^{(L)} x_i^{(L)}

 

で表され、見かけ上随分簡単になりました。

とにかくδ_j^{(L)}を計算する手立てさえ見つかれば、良いという状況ですが、これは非常に簡単な問題です。回帰の問題においては最後の層のことだけを考えれば、

 

y_j = x_j^{(L+1)} = f(z_j^{(L+1)})

 

E({w_{ji}})=\frac{1}{2}\sum_{j}(y_{j}-t_{j})^2=\frac{1}{2}\sum_{j}(f(z_{j}^{(L+1)})-t_{j})^2

 

となっているのです。これをz_j^{(L+1)}で微分するのはたやすく

 

δ_j^{(L)}=\frac{\partial E({w_{ji}})}{\partial {z_{j}^{(L+1)}}}=\{ f(z_{j}^{(L+1)})-t_{j}\}・f'(z_{j}^{(L+1)})

 

となります。回帰の問題では出力層の活性化関数fは恒等変換なので

 

δ_j^{(L)}=z_{j}^{(L+1)}-t_{j}=y_j-t_j

 

ととんでもなく簡単になります。しかも、分類の問題で損失関数が交差エントロピーで、活性化関数がソフトマックス関数となっていても、計算すると上記のように出力と目標値の差という単純な形に計算できます。

 

結局出力での損失の微分は

 

\frac{\partial E({w_{ji}})}{\partial {w_{ji}^{(L)}}}=δ_j^{(L)} x_i^{(L)}=(y_j-t_j)・x_i

 

と非常に単純な形で求まりました。

 

中間層での損失の微分

ここまでは出力層のみに着目していたので、複雑な合成関数の形を意識すること無く解くことができていました。上記では損失関数の微分を具体的に計算するために、出力層に着目しましたが。しかし、下記のような連鎖規則はa_{j}=z_{j}^{(L+1)}でなくとも良く、もっと一般的にa_j=z_{j}^{(l+1)}を間に挟んでも良いのです(問題はそのような連鎖規則をしたときに計算しやすくなるのかどうか)。

 

\frac{\partial E({w_{ji}})}{\partial {w_{ji}}}=\frac{\partial E({w_{ji}})}{\partial {a_{j}}} \frac{\partial a_{j}}{\partial {w_{ji}}}

 

\frac{\partial E({w_{ji}})}{\partial {w_{ji}^{(l)}}}=\frac{\partial E({w_{ji}})}{\partial {z_{j}^{(l+1)}}} \frac{\partial z_{j}^{(l+1)}}{\partial {w_{ji}^{(L)}}}=δ_j^{(l)}x_j^{(l)}

 

こうなったときにδ_j^{(l)}=\frac{\partial E({w_{ji}})}{\partial {z_{j}^{(l+1)}}} を計算する手立てさえあればいいのです。ここで再度連鎖規則を用います。

 

δ_j^{(l)}=\frac{\partial E({w_{ji}})}{\partial {z_{j}^{(l+1)}}}=\sum_k \frac{\partial E}{\partial z_{k}^{(l+1)}} \frac{\partial z_{k}^{(l+1)}}{\partial z_j^{(l)}}

 

ここで最右辺の第一因子はδ_k^{(l+1)}になっています。

第二因子は、順伝搬においてz_k^{(l+1)}=f' \left( \sum_k w_{kj}z_j^{(l)} \right)であるので

 

 

 \frac{\partial z_{k}^{(l+1)}} {\partial z_j^{(l)}} = f' \left( z_j^{(l)} \right) w_{kj}^{(l+1)}

 

と微分が簡単に求まり、まとめると

 

δ_j^{(l)}=f' \left(z_j^{(l)} \right) \sum_k w_{kj}^{(l+1)} δ_k^{(l+1)}

 

と求まります。

この導出された式を整理してみましょう。

 

右辺に関して

 

z_j^{(l)}という値は順伝搬の際に既に得られている値です。

活性化関数fも既知ですから、その微分もすぐに求まります。

w_{kj}^{(l+1)}も現在の重みの値を使えばよく、既知となっています。

わからないのはδ_k^{(l+1)}の値だけです。

しかし上記の式によれば

 

δ_j^{(l+1)}=f' \left(z_j^{(l+1)} \right) \sum_k w_{kj}^{(l+2)} δ_k^{(l+2)}

 

という具合に層の添字を変更すれば値を求めることができ、かつ、出力層でのδ_kに至っては、y_k-t_kと値が簡単に求まります。

すなわち、出力層から始めて、順番にδの値を求めればよいのです。

 

 

 

式を見直す

ニューラルネットの順伝搬法の式を見てみましょう。

 

z_j^{(l+1)}=\sum_i w_{ji}^{(l)}x_i^{(l)}

 

x_j^{(l+1)}=f(z_j^{(l+1)})

 

wによる重み付け和を行った後に、活性化関数を作用させるという流れです。

 

一方で誤差逆伝搬法において、求めなければいけないのはδだけであることが、連鎖規則により明らかになりました。そしてそのδを求める式は以下で表されます。

 

δ_j^{(l)}=f' \left(z_j^{(l)} \right) \sum_k w_{kj}^{(l+1)} δ_k^{(l+1)}

 

wによる重み付け和を行った後に、活性化関数の値を掛けています。

活性化関数の使われ方は違えど、順伝搬においては「前の層の値の線型結合」を使い、逆伝搬では「次の層の値の線型結合」を使っています。

 

f:id:s0sem0y:20170308060835p:plain

 

 

まるでニューラルネットは、データを与えたら前方向にデータを伝搬していき、出力層で誤差を計算したら、その誤差を今度は後ろ向きに流して返してくれるかのようです。

便宜的にδのことを「誤差」と表現し、この手法を誤差逆伝搬法と呼びます

 

δは何のために求めたかというと、誤差関数Ewによる微分を求めるためでした。実際には誤差逆伝搬法とは微分を求めるための高速計算手法であり、高い汎用性を有しています。

 

誤差逆伝搬法により微分を求めたら、勾配法により学習を行うというのがニューラルネットの学習手法となります。

 

s0sem0y.hatenablog.com

s0sem0y.hatenablog.com