モジュール
import jax.numpy as np import jax from jax import random, grad, vmap, jit, lax import matplotlib.pyplot as plt import seaborn as sns import numpyro import numpyro.distributions as dist from numpyro import plate, sample, handlers from numpyro.infer import MCMC, NUTS, SVI, ELBO plt.style.use("seaborn") key = random.PRNGKey(1)
データ
今回はsin波が指数減衰していくような関数を扱うことにします。
def toy_data(N, key=key): x = random.uniform(key, shape=(N, 1), minval=-2, maxval=2) y = np.sin(7*x) * np.exp(-0.5*x) + 0.5*onp.random.randn(N, 1) return x, y.squeeze() x, y = toy_data(100)
ガウス過程
ガウス過程はハイパーパラメータの推論を除けばすべて解析的に求まるため、jax.numpy
を使って計算をすべて書いてしまいます。
カーネル関数
カーネル関数にはRBFを利用しますが、観測分散も一緒に取り込んでいる形式とそうでない形式の両方を準備します。後に予測を行う際、取り込んでいない形式を非対角のブロック行列の計算に用います。
@jit def Kernel(X, Z, var, length): # distance between each rows dist_matrix = np.sum(np.square(X), axis=1).reshape(-1, 1)\ + np.sum(np.square(Z), axis=1)\ - 2 * np.dot(X, Z.T) return var * np.exp(-0.5 / length * dist_matrix) @jit def Kernel_with_noise(X, Z, var, length, noise, jitter=1e-6): return Kernel(X, Z, var, length) + np.eye(X.shape[0]) * (noise + jitter)
ちなみに、最初は素朴にグラム行列をfor文で書いたのですが、jitが異様に遅く使い物になりませんでした(この辺りはTensorFlowのtf.function
の凄みを感じた?)。その後、下記のjax.lax.fori_loop
を利用しましたが、jitは良いとして自動微分に未対応のためハイパーパラメータの推論のときに詰まったためボツとしました。一応コードは下記のようになります。
# fori_loopが自動微分未対応? # @jit # def Kernel(X, Z, var, length): # I = X.shape[0] # J = Z.shape[0] # K = np.zeros(shape=[I, J]) # def body(i, K): # def inner_body(j, K): # k_xz = rbf_kernel(X[i], Z[j], var, length) # K = jax.ops.index_update( # K, jax.ops.index[i, j], k_xz # ) # return K # K = lax.fori_loop(0, J, inner_body, K) # return K # K = lax.fori_loop(0, I, body, K) # return K
予測
グラム行列の計算手段があれば、あとは入力訓練データ、出力訓練データ、入力テストデータを用いて予測を直接書き下すことができます。これはパラメータの周辺化消去とカーネルトリックによって、パラメータが計算上から消え、通常の回帰モデルにおける学習結果と言えるものをデータで直接表現できるということです。
通常は、K_xx_inv
の逆行列計算($O(n ^ 3)$)に多大なコストを要するため、様々な計算方法が提案されています。ここでは素朴にすべてのデータ点を用いて、逆行列を実直に計算する実装となっています。
def predict(rng_key, X, Y, X_test, var, length, noise): k_pp = Kernel_with_noise(X_test, X_test, var, length, noise) k_pX = Kernel(X_test, X, var, length) k_XX = Kernel_with_noise(X, X, var, length, noise) K_xx_inv = np.linalg.inv(k_XX) K = k_pp - np.matmul(k_pX, np.matmul(K_xx_inv, np.transpose(k_pX))) sigma_noise = np.sqrt(np.clip(np.diag(K), a_min=0.)) * jax.random.normal(rng_key, X_test.shape[:1]) mean = np.matmul(k_pX, np.matmul(K_xx_inv, Y)) return mean, mean + sigma_noise
決め打ちハイパーパラメータでの予測
データ点が2つでもとりあえずガウス過程は予測を出せます。今回はひとまず予測の平均値だけをプロットしてみます。
次第にデータ点を増やしていきましょう。
最終段階では真の関数も一緒に表示します。データ点自体が真の関数から誤差を持って観測されることに加え、ハイパーパラメータが決め打ちであるために、上手く予測ができていないようにも見えます。
ガウス過程はデータ点を追加することで、いわゆる普通の回帰モデルにおける学習に相当する結果を直ちに得られます。一方でハイパーパラメータを調整するというのは、普通の回帰モデルで言うと多項式の次数を調整したり非線形の項を入れてみたりと、モデルそのものを調整する作業に相当します(なぜなら、ガウス過程のハイパーパラメータは関数がどれくらい曲がれるのに関わるからである)。
MCMC でのハイパーパラメータ推論
モデル
モデルは非常にシンプルに書くことができます。既にガウス過程における重要な計算は実装しているためです。やらなければならないことは、ハイパーパラメータに事前分布を与えてやることと、ガウス過程の定義である「どのような観測点のセットを選んでも、その観測点らが(データ数の次元での)多変量正規分布に従う」という形式でサンプリングを行うことです。
このときの共分散行列が、入力データ点によるカーネルグラム行列で記述されるのでした。
def model(X, Y): var = sample("kernel_var", dist.LogNormal(0.0, 10.0)) noise = sample("kernel_noise", dist.LogNormal(0.0, 10.0)) length = sample("kernel_length", dist.LogNormal(0.0, 10.0)) K = Kernel_with_noise(X, X, var, length, noise) sample( "Y", dist.MultivariateNormal(loc=np.zeros(X.shape[0]), covariance_matrix=K), obs=Y)
事前分布からのサンプリング
事前分布から選ばれたハイパーパラメータによってサンプリングを実施してみます。 これはよもや決め打ちのハイパーパラメータよりもいい加減な結果になることでしょう(事前分布の分散を見よ)。
trace_model = handlers.trace(handlers.seed(model, key)) prior_Y = trace_model.get_trace(random.normal(key, shape=(500, 1)), None)["Y"] plt.plot(prior_Y["value"], "o")
事後分布の推論
NumPyroでのMCMCは非常に明快なAPIとなっています。 今回のハイパーパラメータはすべて正に値を取るべき変数たちであるので、MCMCの内部で負の値がサンプリングされてしまうと不具合が生じます。そのため、通常では適当な変数変換によってMCMCの空間では実数全体を探索させておき、ハイパーパラメータとして使うときには適切な制約された空間に収まるように仕立て上げます。
NumPyroにおいてそれらの処理は、事前分布の定義域から自動で実施してくれるため特にユーザーが行う設定はありません。(TensorFlow Probabilityではtfp.bijectors
を用いて各々パラメータに対して適切な変換を準備する必要があります。)
def run_inference(model, warm_up, samples, key, X, Y): kernel = NUTS(model) mcmc = MCMC(kernel, warm_up, samples, progress_bar=True) mcmc.run(key, X, Y) mcmc.print_summary() return mcmc.get_samples() samples = run_inference(model, 500, 1000, key, x_data, y_data) sample: 100%|██████████| 1500/1500 [00:13<00:00, 113.48it/s, 11 steps of size 4.40e-01. acc. prob=0.93] mean std median 5.0% 95.0% n_eff r_hat kernel_length 0.06 0.03 0.06 0.02 0.10 402.01 1.00 kernel_noise 0.24 0.06 0.24 0.15 0.33 719.51 1.00 kernel_var 3.85 3.57 2.84 0.82 7.20 386.16 1.00 Number of divergences: 0
plt.figure(figsize=(15, 4)) plt.subplot(131) sns.distplot(samples["kernel_length"]) plt.title("length") plt.subplot(132) sns.distplot(samples["kernel_var"]) plt.title("var") plt.subplot(133) sns.distplot(samples["kernel_noise"]) plt.title("noise")
予測分布
予測分布を出すためには、上記で得た事後分布(からのサンプリング)を用います。 今回は事後分布からのサンプリングを1000点準備しているので、ある1つの $x$ に対して $y$ が1000個計算できることになります。その結果から平均と分散を用いて簡易的に予測分布を表示することができます。
今回は vmap
を用いて、既に実装されているpredict
関数をvectorizeします。
各々の予測において異なる乱数シードを用いるように、乱数シードも1000個準備することとします。
rng_key_predict = random.split(random.PRNGKey(0), num=1000)
predict
関数における、「バッチ処理が必要な変数」のみが引数となるように lambda
式で関数を作ります。
predict_samples = vmap(
lambda rng_key, var, length, noise:
predict(rng_key, x_data, y_data, x_test, var, length, noise)
)
準備が整いました。これで予測分布を計算してみます。
mean, mean_noise = predict_samples( rng_key_predict, samples["kernel_var"], samples["kernel_length"], samples["kernel_noise"] ) mean_mean = mean_noise.mean(0) mean_std = mean_noise.std(0) plt.figure(figsize=(15, 6)) plt.plot(x_data, y_data, "o", alpha=0.5) plt.plot(x_test, mean_mean, "b") plt.fill_between(x_test.squeeze(), mean_mean-3*mean_std, mean_mean+3*mean_std, alpha=0.2) plt.plot(x_true, y_true, "g")
ハイパーパラメータの推論を実施したことで、決め打ちのときより遥かに良い予測ができているようです。