はじめに
PyTorchとjaxの比較用。この手のライブラリを使うには、autogradの使い方を理解することが一番最初の仕事だと思われます。そして、そのautogradの時点で大きく思想が異なっているので、メモしておきます。
使う関数
下記をインポートしている前提で
import torch import jax import jax.numpy as np from jax import vmap, jit, grad
下記の2次元上のスカラー関数
$$ f(x, y) = x ^ 2 - y ^ 2 + 2 x y $$
を微分していきます。
def f(x, y): return x**2 - y**2 + 2*x*y
と書いておきます。
autograd with pytorch
偏微分はそれぞれ
$$ \partial _ x f(x, y) = 2x + 2y \\ \partial _ y f(x, y) = -2y + 2x $$
なので、$x = y = 1$ を与えれば
$$ \partial _ x f(1, 1) = 4 \\ \partial _ y f(1, 1) = 0 \\ $$
となっているはずです。
# 具体的な値を準備 x = torch.tensor(1.0).requires_grad_() y = torch.tensor(1.0).requires_grad_() # 具体的な関数の戻り値を得る f_val = f(x, y) # 関数の戻り値を第一引数に、微分したい変数をタプルでまとめて第二引数に # create_graph=True にしておかないと grad_f を後で続けて微分できない grad_f = torch.autograd.grad(f_val, (x, y), create_graph=True) # 戻り値の第一引数が x による偏微分値 dfdx = grad_f[0] # 戻り値の第二引数が y による偏微分値 dfdy = grad_f[1] print(dfdx) print(dfdy) # tensor(4., grad_fn=<AddBackward0>) # tensor(0., grad_fn=<AddBackward0>)
正しく計算されています。
続けて二階微分を求めてみます。二次関数なので2階微分は定数です。
$$ \partial _ {xx} f(x, y) = 2 \\ \partial _ {xy} f(x, y) = 2 \\ \partial _ {yx} f(x, y) = 2 \\ \partial _ {yy} f(x, y) = -2 $$
となっているはずです(たちの悪い関数ではないので $\partial _ {xy} = \partial _ {yx}$ となっている。$\partial _ {xx} = \partial _ {yx}$ はただの偶然)。
grad_dfdx = torch.autograd.grad(dfdx, (x, y)) grad_dfdy = torch.autograd.grad(dfdy, (x, y)) dfdx2 = grad_dfdx[0] dfdxy = grad_dfdx[1] dfdyx = grad_dfdy[0] dfdy2 = grad_dfdy[1] print(dfdx2) print(dfdxy) print(dfdyx) print(dfdy2) # tensor(2.) # tensor(2.) # tensor(2.) # tensor(-2.)
問題ありません。具体的な値と、その計算に関わった具体的な変数を入れてあげれば微分をその偏微分値をタプルで返してくれるというシンプルな構成のため、割と直感的です。ちなみにここで求まった4つの偏微分値は、ヘッシアンの各成分に対応するので
torch.autograd.functional.hessian(f, (x, y))
# ((tensor(2.), tensor(2.)), (tensor(2.), tensor(-2.)))
によって求まります。なんでこちらは、第一引数を具体的な値ではなく関数で与え、第二引数を具体的な値で渡すのかは不明(もともと具体的な値を渡すようにしていたのは、Tensor自身が履歴を持っているという仕組みだったので、関数を渡す必要などなかったのであった)。
autograd with jax
関数 f
は使いまわします。こちらは、具体的な値は最後まで与えないというのがコーディングの指針になります。できる限り関数を書いていくスタイルです。 grad
関数に関数 f
と、関数 f
の何番目の引数に関する微分なのかを指定し、grad
関数がf
の計算グラフを構築します。
すなわちPyTorchとは真逆で、Tensor自身に多くの情報を持たせません(例えばPyTorchでは計算の履歴やgradの値を各Tensor自体が持っていた)。
# 具体的な値は当面使わない x = np.array(1.0) y = np.array(1.0) # 関数 f を grad に渡すと、勾配を計算する関数が生成される grad_f = grad(f, argnums=[0, 1]) # grad_fに(x, y) を渡すと返り値に関して dfdx と dfdy はそれぞれ第一要素と第二要素にある。 # それぞれの値を返す関数を作っておく(高階微分に必要) dfdx = lambda x, y: grad_f(x, y)[0] dfdy = lambda x, y: grad_f(x, y)[1] # ここで初めて具体的な値が入る print(dfdx(x, y)) print(dfdy(x, y)) #4.0 #0.0
若干手間があるようにも思いますが、設計しているのが常に関数で、最後の最後の欲しいものを流すという感覚はTensorFlowのv1を若干思い起こすところがあります。とは言いつつ、普通にPythonとして動くという点で全く書きやすさは異なります(なので、別に関数を作っていく…という方針じゃないコーディングもできる)。
次に高階微分は下記のようになります。
# dfdx, dfdy という関数が準備されているのでgradにその関数を渡してあげれば良い grad_dfdx = grad(dfdx, argnums=[0, 1]) grad_dfdy = grad(dfdy, argnums=[0, 1]) # 再び各成分を取り出す関数を書いているが、これ以上微分しないならば必要はない。 # 一応形式的に書いておく。 # これは関数fのヘッシアンの各成分を返す関数を書いていることに相当する。 dfdx2 = lambda x, y: grad_dfdx(x, y)[0] dfdxy = lambda x, y: grad_dfdx(x, y)[1] dfdyx = lambda x, y: grad_dfdy(x, y)[0] dfdy2 = lambda x, y: grad_dfdy(x, y)[1] print(dfdx2(x, y)) print(dfdxy(x, y)) print(dfdyx(x, y)) print(dfdy2(x, y))
これはヘッシアンの各成分なので
jax.hessian(f, [0, 1])(x, y) #((DeviceArray(2., dtype=float32), DeviceArray(2., dtype=float32)), # (DeviceArray(2., dtype=float32), DeviceArray(-2., dtype=float32)))
で求まります。
Jax で単回帰
正直お作法はよく分かってないですが
def model(params, x): return params["a"] * x + params["b"] @jit def loss_fn(params, x, y): y_pre = model(params, x) return np.power(y_pre - y, 2).mean() grad_loss = grad(loss_fn, argnums=[0]) def optimize(params, grads, lr=1e-3): return {kp: vp - lr*vg for (kp, vp), (_, vg) in zip(params.items(), grads[0].items())} # # x, y = get_traindata() # params = {"a": np.array(1.), "b": np.array(0.)} # 標準的な考えだと下記のようにするところ #for n in range(10000): # grads = grad_loss(params, x, y) # params = optimize(params, grads) # 上記のループを関数で書いておく @jit def train(epoch, x, y, params): def body_fun(idx, params): grads = grad_loss(params, x, y) params = optimize(params, grads) return params params = jax.lax.fori_loop(0, epoch, body_fun, params) return params params = train(10000, x, y, params)
という感じになりそう。
ちなみにjax.lax.fori_loop
使うとのと使わないのとでは全然スピードが違いました。
%%timeit params = {"a": np.array(1.), "b": np.array(0.)} params = train(10000, x, y, params) # 1000 loops, best of 3: 1.75 ms per loop
%%timeit params = {"a": np.array(1.), "b": np.array(0.)} for _ in range(10000): grads = grad_loss(params, x, y) params = optimize(params, grads) # 1 loop, best of 3: 25.9 s per loop
いや差が凄い!!すごすぎる…。なんだこれ。jit 速すぎ、普通の遅すぎ!!でございました。 PyTorchの方は
%%timeit # parameter に関する勾配が必要なので requires_grad=True とする a = torch.randn(1, requires_grad=True) b = torch.randn(1, requires_grad=True) for i in range(10000): zero_grad(a, b) y_pred = model(x, a, b) loss_value = loss(y_pred, y) # 計算に使われた a, b のgradに値が格納される loss_value.backward() update(a, b, lr=1e-2) # 1 loop, best of 3: 2.81 s per loop
まあ遅くはないけど、Jax.jit何が起きてんですか!!?何かの間違いか??(CPUでやってる限りこんなもんか?でかいモデルをGPUで扱うようになるとボトルネックが変わるか。それでも凄い)
損失の計算や最適化の処理も純粋関数で書くことにすると、PyTorchみたいに各インスタンスの内部変数が他のインスタンスに書き換えられたりするような仕組みとは異なり、完全に入出力だけを見ておけば他の何かが起こった可能性を一切省けます(何か不具合があったとすると、その関数の中に犯人がいると確定する。たぶん。。。)。なにより jax.jit
は関数をコンパイルしてくれる機能なので、クリーンな関数を書いておけば高速化で大いに活躍するjit
周りで痛い目見なくて済みそうです。PyTorchのjit地味に辛いからね…。
x_test = np.linspace(-3, 3, 1000) y_predict = model(params, x_test) plt.plot(x, y, "o") plt.plot(x_test, y_predict)