はじめに
前回は下記の記事で学習率固定で勾配法を実施しました。
今回はウルフ条件を満たすような学習率を各更新時にバックステップで探索し、満たすものを見つけたら直ちにその学習率の更新するという形式で勾配法を実施します。
この記事ではJaxとPyTorchで収束までのステップ数や収束先等の結果はほぼ一致しましたが、速度が圧倒的にJaxの方が速く、PyTorchの計算グラフが変なふうになってしまっている可能性があります(こんなPyTorch遅いわけがない…!)
どなたか見つけたら教えて下さい…。
モジュールインポート
import jax import jax.numpy as jnp from jax import value_and_grad, jit, grad import torch import torch.autograd as ag import matplotlib.pyplot as plt import numpy as np
最適化問題はこんなもの
def objective(x): f = x[0]**4 + 2*x[0]**3 + 3*x[1]**2 + 2*x[0]*x[1] - x[1] return f x1 = np.linspace(-4, 4) x2 = np.linspace(-4, 4) xx1, xx2 = np.meshgrid(x1, x2) f = objective([xx1, xx2]) x_init = jnp.array([3., 3.]) f_init = objective(x_init)
Jax
勾配関数と線形探索関数を準備
勾配と評価値を得る関数は、value_and_grad
に目的関数を渡せば良いだけ。自動微分バンザイです。
繰り返しiterationの中で利用されるpure関数なら何でもjit
しておけば良いと思います。(使うのが一回とかだと返ってコンパイルがネックになるかも)
線形探索の方は学習率 a
を引数に取る関数として書きます。これを自動微分関数を生成する grad
に渡せば導関数が得られます。あとは探索を続ける条件を関数 cond(a)
として書いておき、探索を続ける場合の処理も body(a)
として書いておけば、学習率の初期値a_init
を予め設定し jax.lax.while_loop(cond, body, a)
に渡せば body
が True
を返す限りbody
を繰り返し続けます。
jit
前提ではPythonの構文は書かないでjax.lax
の制御関数を利用しましょう。コンパイルの時間が全く異なります。
v_and_g = jit(value_and_grad(objective)) @jit def linear_search(x, dx, tau1=0.5, tau2=0.8): a_init = 1.0 beta = 0.999 obj_a = lambda a: objective(x + dx*a) grad_obj_a = grad(obj_a) def armijo_cond(a): return obj_a(a) > obj_a(0.0) + tau1*grad_obj_a(0.0)*a def wolfe_cond(a): cond1 = armijo_cond(a) cond2 = grad_obj_a(a) < tau2*grad_obj_a(0.0) return cond1 + cond2 body = lambda a: a*beta a = jax.lax.while_loop(wolfe_cond, body, a_init) return a
最適化実行
あとは普通の勾配法と同じです。学習率 a
だけ各ステップに自動で決定されます。
x = jnp.array([3.0, 3.0]) tmp_val = 1e5 x_list = [x] a_list = [] while True: val, grad_val = v_and_g(x) a = linear_search(x, -grad_val) x = x - a*grad_val a_list.append(a) x_list.append(x) if jnp.abs(val - tmp_val).sum() < 1e-3: break tmp_val = val
ちなみに10ステップで完了し、学習率の遷移はこんな感じです。ちなみに a = 1e-3
とかで決め打ちすると3桁ステップ掛かりました。しかも局所解に捕まる始末でございます。線形探索、単純だけど強力なのね。ただ、今回利用しているウルフ条件は、バックステップ法で条件を満たすものが見つかるとも限らなさそうで、実際やっていることは怪しい。もっと効率は悪いがアルミホ条件で妥協するほうが安全ではありそうです。
plt.plot(a_list)
plt.figure(figsize=(7,7)) plt.contourf(xx1, xx2, f, cmap='Blues') for i, x_trj in enumerate(x_list): plt.scatter(x_trj[0], x_trj[1], c="r", s=80) if i == len(x_list) - 1: plt.scatter(x_trj[0], x_trj[1], c="g", s=80) plt.title(f" iter: {len(x_list)} solution : {val}")
PyTorch
線形探索関数準備
はい、torch.nn.Module
様とtorch.nn
様、及び PyTorch lightning 様に管理され、ほとんどネットワークを繋いでいるだけの使い方しかしていないため、細かい計算グラフの切り方など間違えている可能性があります。ご指摘願います。
Twitterにてご指摘をいただき a = a*beta
を with torch.no_grad()
コンテキスト内に収めたら3倍高速化しました。
def linear_search(x, dx, tau1=0.5, tau2=0.8): beta = 0.999 init_a = 1. obj_a = lambda a: objective(x + dx*a) zero = torch.zeros([], requires_grad=True) a = torch.tensor(init_a) obj_0_val = obj_a(zero) grad_obj_0_val = ag.grad(obj_0_val, zero)[0] def cond(a): a.requires_grad_() obj_a_val = obj_a(a) grad_obj_a_val = ag.grad(obj_a_val, a)[0] cond1 = obj_a_val <= obj_0_val + tau1*grad_obj_0_val*a cond2 = grad_obj_a_val >= tau2*grad_obj_0_val a.detach() return cond1 and cond2 while True: if cond(a): break with torch.no_grad(): a = a*beta return a.item()
最適化実行
x = torch.tensor([3.0, 3.0], requires_grad=True) tmp_val = 1e5 x_list = [x.detach().numpy()] a_list = [] while True: val = objective(x) grad_val = ag.grad(val, x)[0] a = linear_search(x.clone().detach(), -grad_val.clone().detach()) x = x - a*grad_val x_list.append(x.detach().numpy()) a_list.append(a) if torch.abs(val - tmp_val).item() < 1e-3: break tmp_val = val
可視化は同じなので結果だけ貼って省略します。
結果
google colabで
Jax 0.365 sec
PyTorch 21 sec -> 7.8 sec
絶対なんかおかしい。Jax非同期ディスパッチでPythonでの計測がオカシイにしても、実際体感でJaxは1秒以内、PyTorchは20秒は掛かってた…。 PyTorch修正後 8秒未満まで短縮。 Jaxは jax.lax.while_loop
を利用せず Pythonでループを回すと同様に8秒程掛かったことから、jax
の制御式の実装がエゲツナク速い模様です(確かにiteration1回1回はかなり軽いので、条件判定等が最もボトルネックだったかもしれないことを考えると、学習ループの中に更に探索ループがいる場合、そこが利いてくるのはそのとおりかも)。