"機械学習","信号解析","ディープラーニング"の勉強

HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学などをテーマに扱っていきます

ニューラルネットワークによる学習の停滞はどこから生ずるか

 

 

現在機械学習ではディープラーニングの活躍が目立っています。

その基礎はニューラルネットワークの学習にあり、この学習を知っているのとそうでないのとでは、各手法に関する理解度が大きく異なってくるものと思います。

今回はニューラルネットワークの学習における不思議の1つ、「学習の停滞」の原因について述べてみたいと思います。

 

 

はじめに

ニューラルネットワークでは学習が停滞したと思いきや、しばらく学習を根気強く続けていると、学習が進むときが訪れます。既にニューラルネットワークのプログラムを動かしたことがある人ならば経験したことがあるでしょう。なぜに学習が停滞し、そして再び進むということが起こるのでしょうか。 今回はそのことについて述べていきます。

 

勾配法

ニューラルネットは通常誤差逆伝播法によって学習を行います。ご存知の通り、これは損失関数のパラメータによる勾配を求め、損失関数を減少させるようにパラメータを更新していく勾配法の一種です。

勾配法

勾配法では、損失関数に対するパラメータについて

 

{\bf w} ← {\bf w} -η\frac {\partial L}{\partial {\bf w}}

 

と更新することで、損失関数が先程よりは幾分か減少した\bf wに変更することを繰り返していきます。ηは学習率と呼ばれ、どれくらい大きく進ませるかを決定するものです。

 

具体的な例題

勾配法とは損失関数を減少させる方向にパラメータを少しずつ更新する至ってシンプルな手法です。最も簡単な例題を考えて、パラメータ{\bf w} = (w_1,w_2)^Tによる損失関数が以下で表されるとしましょう。

 

L({\bf w}) = |{\bf w}|^2 = w_1^2 + w_2 ^2

 

損失関数が最小になるのは|{\bf w}|=0⇔w_1=w_2=0だとすぐにわかります。それは二乗が必ず正の値になるので、0以外にありえないとも考えられますし、あるいは微分して0になるところだという考え方もできます。勾配法では後者の考え方、「微分して0になるところ」というのが重要です。

 

具体的に勾配法を見てみましょう。

 

損失関数の\bf wによる微分\frac {\partial L}{\partial {\bf w}}は、スカラーLをベクトル{\bf w} = (w_1,w_2)^Tで微分しており、これは以下で定義されます。

 

\frac {\partial L}{\partial {\bf w}}=(\frac{\partial L}{\partial w_1} ,\frac{\partial L}{\partial w_2})^T

 

スカラーをベクトルの各成分で微分したベクトルになります。このベクトルは、損失関数が上昇する方向を示しているため、勾配法では、

 

{\bf w} ← {\bf w} -η\frac {\partial L}{\partial {\bf w}}

 

と値を更新します。これにより損失関数が先程よりは幾分か減少した\bf wが得られるはずです。もちろん、ある\bf wから少しずつ変更していくので、\bf wに初期値が必要になります。今回はとりあえず初期値を{\bf w} = (2,3)^Tとして始めてみましょう。上記の損失関数の設定で勾配を計算すると

 

\frac {\partial L}{\partial {\bf w}}=(\frac{\partial L}{\partial w_1} ,\frac{\partial L}{\partial w_2})^T = (2w_1,2w_2)^T

 

となっていますから更新は

 

{\bf w} ← (2,3)^T -η(4,6)^T

 

です。運良く学習率をη=0.5と設定していた場合は直ちに{\bf w} = (0,0)^Tが得られます。このときの勾配は(0,0)^Tと計算され、以降更新は行われません。

仮に学習率をη=0.1と設定していた場合は

 

{\bf w} ← (2,3)^T - 0.1(4,6)^T = (2 - 0.4,3 - 0.6)^T =(1.6,2.4)^T

 

と新しい値が得られます。この値における勾配を再び計算すると

 

\frac {\partial L}{\partial {\bf w}}=(2w_1,2w_2)^T=(3.2,4.8)^T

 

と求まるので、再び同じように更新をしていきます。

 

{\bf w} ← (1.6,2.4)^T - 0.1(3.2,4.8)^T = (1.6 - 0.32,2.4 - 0.48)^T =(1.28,1.92)^T

 

少しずつ最適解である(0,0)^Tに近づいています。いつか最適解にたどり着きそうなのは想像が着きますね。

 

 

学習が停滞する原因

上の例題は、紙に書いて進めていけば分かる通り、進む毎に勾配\frac {\partial L}{\partial {\bf w}}が小さくなっていきます。

 

ということは単純に考えて、学習が停滞するのは、損失関数が現在のパラメータの値においてかなり平坦な状態であるからと言えます。

 

勾配法の落とし穴

解に近づくと、平らな場所に行き着くため、結果的に勾配は小さくなり、学習は停滞し始めます。では「学習が停滞し始める=解がすぐそこに迫っている」とも言えるのでしょうか。

先ほどの例題ならば答えは「yes」です。微分して0になるような場所は一箇所しか無いため、「平坦な場所=解」であると結論づけられるためです。通常このような最適化問題は「凸最適化問題」と言います。そうでない場合は「非凸最適化問題」と呼ばれます。

 

通常ニューラルネットワークの損失関数というのは「非凸」です。微分して0になるような点がいっぱいあります。例えば

 

y = 2x^4 - 8x^3 -26x^2 +x - 5のような関数をパラメータxについて勾配法で初期値をx=-3として解いていくと、きっと以下のように学習が終わります。青の点において、勾配は0なので更新はされません。もう少し進めばもっと低い値があるのに、勾配法ではそれは得られないのです。

 

f:id:s0sem0y:20161111211801p:plain

 

最適解でなくとも平坦な場所においては、勾配は次第に小さくなり、そしてやがて止まってしまいます。このようなケースを、局所最適解に陥ったとか、ローカルミニマムに嵌ったとか言います。今回の簡単なケースならば、初期値を色々変えてみることで一応回避は可能です。例えば初期値がx = 2ならばきっと最適解に辿り着くでしょう。

 

今は一次元で見ましたが、二次元の場合はローカルミニマム以外にも勾配が0になるような場面があります。「鞍点」と呼ばれる場所です。以下の図の赤い点が鞍点です。

 

f:id:s0sem0y:20161111213342p:plain

wikipedia [saddle point]

 

このような場所でも学習は止まってしまいます。赤い点にピッタリ行き着かないにしても、この周辺では非常に勾配が小さいため、学習は中々進まなくなるのです。

 

学習が停滞し、再び進む原因

ニューラルネットの学習が停滞し、再び進むという現象は、この鞍点に囚われているところから何とか抜けだした瞬間に訪れるのです。

 

通常ローカルミニマムに捕まった場合は、特殊なことをしない限り勾配法では抜け出すことができません。なぜなら、どこに進もうとしても坂を登らなければならないからです。坂を登ることを許容してしまえば一見問題は解決しそうですが、するとローカルミニマムでないような場所で、わざわざ最適解から遠ざかる動きをする可能性もあるためリスクも伴います。しかし、鞍点ならばどうでしょう。少し移動すれば勾配がしっかり下る方向に存在しているため、少し揺さぶってやれば学習は進むはずです。

 

ニューラルネットで勾配の計算を工夫する様々な手法があるのはご存知だと思います。

AdaGradやAdaDelta、Adamなどはまさしく、この鞍点から素早く抜け出すための工夫が施された手法なのです。

 

鞍点はどこから現れるのか

ニューラルネットには、学習を停滞させる鞍点が非常に多く存在します。

もう少し鞍点のことについて詳しく整理してみましょう。

 

勾配が「0」になる点

これが学習を停滞させる原因で、鞍点の周辺も基本的には勾配が非常に小さくなります。

ただし、鞍点はローカルミニマムと異なって少し移動すれば小さいながらも損失関数を減少させる勾配は存在するため、一見学習が進むように思われます。しかし、実際にはニューラルネットがとても長い時間鞍点に囚われ続けることがかなり多いです。

なぜでしょうか。勾配が0になる点というのをもう少し詳しく考えましょう。

 

パラメータを少し変更しても、損失関数がまったく変化しない点

勾配が0というのは、言い換えるとパラメータを少し変更しても損失関数がまったく変化しないということです。微分の定義に戻ってみれば当たり前といえば当たり前でしょう。

実はこの考え方が大事です。なぜなら、鞍点が実は「1つの点」ではない可能性があるからです。鞍点が非常に多く連なり、結果的に広大な線あるいは面を形成している可能性があるのです。

すなわち少し移動しても、そこはやはり鞍点で、ほとんど勾配が得られないという状況がありうるのです。

 

簡単な例

とっても退屈なニューラルネットワークを考えましょう。

1入力1出力で、隠れユニットが1個のニューラルネットワークです。

 

f:id:s0sem0y:20161111215549p:plain

 

活性化関数が中間層でf_1、出力層でf_2とします。

するとこのニューラルネットワークは

 

y = f_2(vf_1(wx))

 

という変換を行います。(バイアス項はメンドウなので0としています)

さて、ニューラルネットワークの出力をパラメータ{\bf θ}=(v,w)^Tを色々変えて最適化してやろうと試みるわけです。損失関数は目標値tを使って

 

L({\bf θ}) = \sum_n (t - y_n)^2 = \sum_n (t- f_2(vf_1(wx_n)))^2

 

などと設定されるわけですが、v,wを調整して損失関数を減少させるというのは、結局出力yv,wを調整して上手いことコントロールしてやりたいということです。すなわち、損失関数を減少させるためには、v,wを変更した結果、出力yの値も変更を受けなければなりません。

 

v,wは色々な値を取れます。色々変更してやりたいのです。

仮にv = 0だったとしましょう。するとwがどれだけ調整を施されても出力は一切変わりません。ずっとy=0です。従って損失関数も一切変化しません。

wを変化させても損失関数に変化が起こらない。これは損失関数のwに関する勾配が0ということです。w=0.1,0.2,0.3,0.4,......,1,.....,10000と変えてもv = 0の領域で勾配が0なのです

 

これはあまりにもつまらない例題です。しかし、出力を変化させないパラメータの領域が線となって連なっている例となっています。

 

では入力多次元で\bf xとして中間層が2つ、出力が一次元yとしましょう。

 

f:id:s0sem0y:20161111221621p:plain

 

この場合の出力は

 

y = f_2(v_1f_1({\bf w_1x})+v_2f_1({\bf w_2x}))

 

となります。f_2は何らかの変換をする関数ですが、問題は中身なので個々に着目しましょう。1つは先ほどと同じようにv_1=0という場合には、\bf w_1がどんなふうに値を変えても出力に影響を及ぼしません。すなわち、v_1=0において\bf w_1による損失関数の勾配は0です。仮にw_1が二次元のベクトル(すなわち入力が二次元)ならば、パラメータ空間に面上の鞍点が生ずることになります。

 

次に、\bf w_1 = w_2の場合、中間層のユニットはいずれも同じ値を取ることになります。

従って、v_1とv_2の値によってのみ出力が変更されるのですが、仮にv_1+v_2=constであるように値を変更したとしても出力は一切変更を受けません。例えばv_1=2,v_2=3v_1=8,v_2 =-3はまったく同じ出力しか出せないのです。よってこのような場所に鞍点が生じます。

 

通常のニューラルネットワークだとどうなるか、想像に硬くないでしょう。このような関係が至るところで現れるために、実はパラメータ空間は鞍点だらけなのです。

しかも今回紹介したのは、単なる鞍点ではありません。空間的に広がりを持っていますから、少し動かしてあげた程度では、そこから抜け出すのは困難かもしれないのです。

 

特異点

先ほどまで「鞍点」と表現していましたが、実は上記のように現れる「パラメータを変更しても関数の出力が変化しない領域」を「特異点」と言います。

通常「鞍点」とは勾配が0になるが、少し移動すれば上昇も下降もできるような点のことを言います。従って特異点の方が圧倒的に質が悪いのです(少しの移動じゃ平坦な場所から抜け出せない)。

なぜ特異「点」と表現するのかというと、空間的に広がっているそのような領域は、関数にとってはまったく同一のものであるため、それを点とみなして圧縮してしまう方が便利だからです。

 

これより詳しい議論は

 

別冊数理科学 情報幾何学の新展開 2014年 08月号 [雑誌]

別冊数理科学 情報幾何学の新展開 2014年 08月号 [雑誌]

 

 

の第14章「学習の力学と特異点」に述べられています。

 

まとめ

・学習の停滞と再開は鞍点によって生ずる

・鞍点は勾配が0になる点

・勾配が0になる点⇛パラメータを少し変更しても出力にまったく変化を及ぼさない点

・ニューラルネットには特異点という質の悪い領域が広がっている

 

今回は鞍点から話をしていきました。

通常の鞍点自体ももちろん学習には悪い影響を与えており、鞍点の付近では勾配はやはり小さいため、停滞は免れません。ここから早いとこ抜け出す方法を様々な最適化手法が提供してくれています。それらについては

 

s0sem0y.hatenablog.com

 

ここでの紹介が大変参考になります。