はじめに
本当はTensorFlow2が世の中で使われるようになって、情報も増えるのが一番嬉しいのですが、ちょっと周囲の状況も含めてPyTorch続投の兆しが強いため、確率的プログラミング言語としてPyroを選択する可能性も出てきました。というわけでPyroの記事です。
Pyro
PyTorchをバックエンドとした確率的プログラミング言語(PPL)です。PPLの名に恥じないくらい、確率モデリングが容易に可能で、TensorFlow Probabilityほど剥き出しのTensorをアレコレ扱わなくても済みます。今回はそんなPyroを支えている premitive
と poutine
の基本的なものを触っていきます。
$ pip install pyro-ppl
で簡単にインストールできるのでPyTorchユーザーはぜひ試してみてください。 また以下のインポート文を前提とします。
import numpy as np import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.nn.functional as F import pyro plt.style.use("seaborn")
primitives
確率変数の実現値 sample
pyro.sample
は指定した分布クラスのインスタンスから確率変数の実現値を生成する関数です。出てくるのは普通にPyTorchで使われるtensor
です。
$$ x \sim {\rm Normal}(x\mid 0, 1) $$
x = pyro.sample("x", pyro.distributions.Normal(0, 1)) print(x) # tensor(0.4913)
渡している分布のインスタンスも pyro
モジュールが提供しているものであることに注意してください(PyTorchにも torch.distributions
がある)。
Pyroでは pyro.sample
で得られた値を確率変数として自動で認識して、MCMCやVIで必要な対数確率の計算を自動で行う仕組みが備わっています。そのときには、pyro.sample("hoge", dist)
の第一引数で与えた名前が使われます。この名前を使って確率変数を認識しているので、区別すべきサンプリングにはちゃんと異なる名前を付けなければなりません。
もしも複数のサンプルを$\rm i.i.d.$ から得たい場合には、普通に for
ループを回すことで実現できます。
$$
x _ i \sim {\rm Normal} (x\mid 0, 1)
$$
$$
{\rm where}\ \ i = 1,\cdots, N
$$
N = 100 x_list = [] for _ in range(N): x = pyro.sample("x", pyro.distributions.Normal(0, 1)) x_list.append(x) plt.hist(x_list)
条件付き独立のベクトル化 plate
上記のように for
ループをPythonで回すのは効率的ではありません。また、サンプルをすべて保持するためにリストを経由するのもコストが高いでしょう。
pyro.plate
を用いると独立なサンプルをベクトル化した状態で得ることができます(独立なサンプルなのだから、順番に得ても、同時に得ても一緒である。サイコロを1回ずつN回振っても、N個のサイコロを1回振っても一緒なのと同じだ)。
with pyro.plate("samples", N): x = pyro.sample("x", pyro.distributions.Normal(0, 1))
コンテキスト内のサンプルは $\rm i.i.d.$ として扱われ、指定した個数のサンプルを同時に得ることができます。 ちなみに表題の「条件付き独立」とはどういうことかといいますと、例えば $x$ と $y$ がそれぞれ $z$ に依存してサンプリングされる下記のケース
$$
\begin{align}
x & \sim p(x \mid z) \\
y & \sim p(y \mid z)
\end{align}
$$
で、$x$ と $y$ が独立な場合には、$z$ に関して条件付き独立であるといいます。今、N個のサンプルを
$$ \begin{align} x _ 1 & \sim p(x \mid z) \\ x _ 2 & \sim p(x \mid z) \\ &\vdots \\ x _ N & \sim p(x \mid z) \end{align} $$
と得ることを考えると、これは各サンプル $x _ i$ が互いに条件付き独立であるというわけです。確率モデリングではある確率変数が他の確率変数に影響を持っているケースは多々あり、いわゆる「独立」というのは、何かに関しての条件付き独立である場合が多いです。Pyroでも
z = pyro.sample("z", pyro.distributions.Normal(0, 1)) with pyro.plate("samples", N): x = pyro.sample("x", pyro.distributions.Normal(z, 1))
などとすることで、正規分布から生成される $z$ の実現値を平均とした正規分布から生起する $x$ を条件付き独立で $N$ 個得ることができます。より具体的にはベータ事前分布を指定したベルヌーイ分布でコインの裏表をモデル化した場合には、下記のように書いたりすることになります。この例では、一度だけベルヌーイ分布の確率パラメータをベータ分布からサンプリングし、そのサンプリングされた値を使って、ベルヌーイ分布から $N$個のサンプルを得ています。
p = pyro.sample("p", pyro.distributions.Beta(1.5, 1.5)) with pyro.plate("conditional_independent_sample", N): x = pyro.sample("x", pyro.distributions.Bernoulli(p)) print(p) # tensor(0.3506) plt.hist(x)
階層モデル
これまでの例のように、Pyroで確率モデルを書く際には、データのサンプリングの様子(生成モデル)を直感的に書き下すことが可能です。 例えば以下のように(無意味な)階層モデルも、Pyroでは簡単に表現することができます。
$$ \begin{align} \mu &\sim \mathcal {\rm LogNormal} (\mu\mid 0, 10)\\ \sigma &\sim \mathcal {\rm InvGamma}(1, 1) \\ x _ i & \sim {\rm LogNormal}(x\mid \mu, \sigma)\\ y _ i & \sim {\rm LogNormal}(y\mid \mu, \sigma) \\ z _ i & \sim {\rm Gamma}(z \mid x _ i, y _ i) \\ o _ i & \sim {\rm Poisson}(o \mid z _ i) \end{align} $$
N = 100 def model(): mu = pyro.sample("mu", pyro.distributions.LogNormal(0, 10)) sigma = pyro.sample("sigma", pyro.distributions.InverseGamma(1, 1)) with pyro.plate("plate", N): x = pyro.sample("x", pyro.distributions.LogNormal(mu, sigma)) y = pyro.sample("y", pyro.distributions.LogNormal(mu, sigma)) z = pyro.sample("z", pyro.distributions.Gamma(x, y)) o = pyro.sample("o", pyro.distributions.Poisson(z)) return [mu, sigma, x, y, z, o] samples = model() print(samples) ''' [tensor(0.0002), tensor(0.7015), tensor([0.4257, 1.0627, 7.0144, 1.4176, 0.4780, 2.9072, 1.1962, 4.7277, 0.9379, 0.7680, 2.7863, 0.7451, 1.0525, 1.6884, 0.5842, 3.0398, 1.9079, 2.4766, 0.7915, 1.0864, 3.6698, 1.7241, 3.6868, 0.3621, 0.5752, 0.8354, 0.5892, 1.0063, 0.7739, 0.8256, 1.5383, 0.5119, 0.4068, 0.2762, 1.0071, 3.6454, 1.3721, 0.7472, 0.9248, 0.4739, 0.7896, 0.3628, 0.5335, 0.8317, 2.6559, 0.7650, 1.5906, 3.0238, 0.5154, 0.7295, 4.8089, 0.9203, 1.5637, 0.5140, 0.7987, 0.9642, 1.4488, 1.3863, 1.4500, 1.3132, 0.5954, 0.7617, 1.1997, 1.8335, 3.1235, 0.4803, 0.4280, 3.4608, 1.3163, 0.4663, 1.3383, 1.1243, 1.9121, 1.3954, 1.7118, 0.9123, 0.6367, 0.2152, 0.6787, 0.4656, 0.9141, 1.6995, 1.2588, 0.5432, 1.5837, 0.1930, 0.6416, 1.1732, 0.9614, 0.9492, 0.3509, 0.8111, 1.2070, 0.7340, 2.9652, 0.9011, 0.7872, 1.1136, 0.3708, 1.4184]), tensor([1.2443, 0.7041, 1.2711, 1.0473, 1.1280, 1.1641, 1.4434, 0.9209, 2.8864, 0.6230, 1.6070, 0.5382, 1.5361, 1.5297, 0.7564, 4.8376, 1.4728, 0.7040, 0.4175, 0.9239, 0.4513, 0.4707, 0.5097, 0.6827, 1.9157, 2.1517, 0.2790, 0.4307, 1.7413, 1.6795, 0.7066, 1.5730, 0.7974, 2.0715, 0.4106, 1.8119, 0.2777, 0.6305, 1.2698, 3.3913, 0.2719, 0.4372, 0.5516, 1.2283, 0.1913, 1.8327, 0.7662, 1.5114, 1.5137, 0.2557, 1.5318, 1.0386, 0.3560, 1.6855, 0.7154, 0.5314, 2.0687, 1.1398, 1.0672, 0.3195, 0.9252, 1.4356, 0.3071, 1.5157, 0.6100, 0.5258, 0.7779, 3.1353, 1.0920, 0.1867, 1.5206, 0.8126, 0.4158, 4.2453, 1.4318, 0.5912, 1.2107, 1.9712, 1.0595, 0.8044, 1.6302, 1.9796, 0.8921, 0.7125, 1.3348, 0.5718, 0.2341, 1.5576, 2.4383, 0.4278, 2.5422, 0.6365, 1.2151, 2.5433, 2.4934, 0.8783, 1.5873, 0.9414, 4.7714, 0.7473]), tensor([4.3462e-02, 2.3453e+00, 3.5184e+00, 1.6382e+00, 2.9499e-01, 9.1393e-01, 3.8286e+00, 6.5808e+00, 4.5203e-01, 7.8937e-01, 1.4508e+00, 4.4677e-01, 4.3575e-02, 6.4304e-01, 8.7083e-02, 1.1236e+00, 2.1808e+00, 2.1357e+00, 2.2062e+00, 1.7447e+00, 1.1076e+01, 2.3367e+00, 4.1230e+00, 1.2337e-02, 2.5558e-01, 5.1439e-01, 6.8960e-01, 1.2596e+00, 4.0812e-02, 2.5280e-01, 9.9072e-01, 2.2785e+00, 1.6162e-01, 2.8976e-10, 1.1553e+01, 2.2747e+00, 1.3333e+01, 3.3048e+00, 4.4796e-01, 1.6227e-02, 4.3604e+00, 3.6324e-01, 2.5430e+00, 8.2950e-01, 3.5135e+01, 1.3029e+00, 4.1930e+00, 2.5822e-01, 4.0985e-01, 3.9308e-01, 3.8958e+00, 2.3114e+00, 5.4542e+00, 3.9611e-01, 3.3914e+00, 4.5571e+00, 1.5696e+00, 1.2822e+00, 1.1186e+00, 1.0413e+01, 1.8252e+00, 1.5227e+00, 7.1146e+00, 2.7419e+00, 5.5266e+00, 1.2939e-01, 7.2556e-01, 9.9216e-01, 1.4511e+00, 8.0014e+00, 7.5821e-01, 1.9351e-01, 4.3107e+00, 3.0771e-01, 2.1022e+00, 8.1710e-01, 5.0636e-01, 1.3115e-01, 4.0444e-02, 6.6399e-01, 3.8530e-01, 1.1365e+00, 8.6651e-01, 2.2942e-01, 2.7417e-01, 5.7365e+00, 6.3949e-02, 2.0346e-01, 3.1939e-01, 1.6593e+00, 1.9173e-03, 1.0179e-01, 9.9543e-01, 6.1918e-02, 9.3573e-01, 7.3790e-02, 6.9964e-02, 1.7703e-01, 8.3770e-04, 3.1362e+00]), tensor([ 0., 4., 2., 4., 0., 0., 5., 10., 1., 0., 1., 0., 0., 1., 0., 1., 2., 0., 4., 2., 10., 1., 5., 0., 0., 1., 0., 4., 0., 0., 0., 2., 0., 0., 9., 4., 13., 2., 0., 0., 9., 0., 1., 1., 36., 1., 4., 0., 0., 0., 3., 1., 5., 0., 5., 7., 0., 0., 1., 13., 2., 2., 7., 3., 5., 1., 1., 0., 1., 6., 1., 0., 4., 0., 3., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 5., 0., 0., 1., 4., 0., 0., 2., 0., 0., 0., 1., 0., 0., 1.])] '''
確かに pyro.plate
の外の変数はサンプルが1個で、内部の変数は(条件付き独立で)指定した個数得られているようです。
変分パラメータを扱う param
さて、確率モデルの記述は割と簡単にできそうであります。一方で我々が実際にPPLを使うときには、データから分布を推論したいケースがほとんどでしょう。単にサンプリングをして遊びたいのではありません。
Pyroではパラメータをグローバルで管理する仕組みとなっており(これは、なんというか、如何なものかと思うが…)、pyro.param
で生成された値はグローバル空間に辞書のような形で登録されます。
w = pyro.param("weight", init_tensor=torch.tensor([5.], requires_grad=True)) b = pyro.param("bias", init_tensor=torch.tensor([1.], requires_grad=True)) print(w) print(b) # tensor([5.], requires_grad=True) # tensor([1.], requires_grad=True) print(pyro.get_param_store()) # <pyro.params.param_store.ParamStoreDict object at 0x7f3c4241c7f0> for param in pyro.get_param_store(): print(param) # weight # bias
ここでも、pyro.param("hoge", tensor)
の第一引数である名前を使って管理しているため、区別しなければならないパラメータには然るべき名前をつける必要があります。
pyro.param
関数に、すでに存在する名前を引数として与えるとどうなるかというと、すでに登録されているパラメータの値を tensor
として返してきます。もしも名前に被りがある状態でモデルを組んでしまった場合は、明らかに意図せぬ挙動となることに注意しましょう。
print(pyro.param("weight").item()) print(pyro.param("bias").item()) # 5.0 # 1.0
poutine
モデルの様子を把握する trace
Pyroの便利な仕組みを支えているのは pyro.poutine
と呼ばれるモジュールです。このモジュールによって確率モデルのサンプリングの様子を認識できるようになっています。
例えば、階層モデルを作って pyro.trace
に渡してやりましょう。
N = 100 def model(): mu = pyro.sample("mu", pyro.distributions.LogNormal(0, 10)) sigma = pyro.sample("sigma", pyro.distributions.InverseGamma(1, 1)) with pyro.plate("plate", N): x = pyro.sample("x", pyro.distributions.LogNormal(mu, sigma)) y = pyro.sample("y", pyro.distributions.LogNormal(mu, sigma)) z = pyro.sample("z", pyro.distributions.Gamma(x, y)) o = pyro.sample("o", pyro.distributions.Poisson(z)) tr_model = pyro.poutine.trace(model) print(tr_model) # <pyro.poutine.trace_messenger.TraceHandler at 0x7fb0b83ac860>
なんだかよくわからない TraceHandler
なるクラスが返されたようです。このあたり、Pyroの設計をしっかりと把握するまでは無理に理解する必要はないでしょう(というかちゃんとわかってない。やはりTensor剥き出しで自分で管理するTensorFlow Probabilityの方が、仕組みが気になるような人にはある意味とっつきやすいのである)。
さて、仕組みの理解はさておいて、得られたクラスのメンバ関数である get_trace()
を呼んでみましょう。
trace = tr_model.get_trace() print(trace) # <pyro.poutine.trace_struct.Trace at 0x7fb0b83aceb8>
また何やらよくわからんものが返ってきました。この Trace
クラスは、先程自分で作ったモデルの情報を把握してくれているものです。
例えば作った model()
関数の引数が何であって、戻り値が何であって、更に内部で生成されている確率変数が何なのかをすべて把握しているのです。その情報は Trace.nodes
という内部変数に記録されています。
print(trace.nodes) ''' OrderedDict([('_INPUT', {'args': (), 'kwargs': {}, 'name': '_INPUT', 'type': 'args'}), ('mu', {'args': (), 'cond_indep_stack': (), 'continuation': None, 'done': True, 'fn': LogNormal(), 'infer': {}, 'is_observed': False, 'kwargs': {}, 'mask': None, 'name': 'mu', 'scale': 1.0, 'stop': False, 'type': 'sample', 'value': tensor(1.0680e-12)}), ~~~~ 中略 ~~~~~~~~ ('o', {'args': (), 'cond_indep_stack': (CondIndepStackFrame(name='plate', dim=-1, size=100, counter=0),), 'continuation': None, 'done': True, 'fn': Poisson(rate: torch.Size([100])), 'infer': {}, 'is_observed': False, 'kwargs': {}, 'mask': None, 'name': 'o', 'scale': 1.0, 'stop': False, 'type': 'sample', 'value': tensor([7.0000e+00, 0.0000e+00, 1.4250e+03, 4.7000e+01, 0.0000e+00, 1.6200e+02, 0.0000e+00, 0.0000e+00, 1.6220e+03, 1.0000e+00, 7.0000e+00, 4.0000e+00, 7.7000e+01, 0.0000e+00, 1.7000e+01, 1.0000e+00, 6.0000e+00, 5.0700e+02, 0.0000e+00, 4.0000e+01, 0.0000e+00, 0.0000e+00, 1.7000e+01, 0.0000e+00, 1.0000e+00, 2.0000e+00, 0.0000e+00, 6.9720e+03, 0.0000e+00, 1.0000e+00, 1.0000e+00, 3.1950e+03, 6.8000e+01, 0.0000e+00, 0.0000e+00, 7.0000e+00, 0.0000e+00, 4.0000e+00, 0.0000e+00, 4.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.4000e+01, 3.0000e+00, 0.0000e+00, 1.9430e+03, 0.0000e+00, 0.0000e+00, 8.1000e+01, 0.0000e+00, 1.1300e+02, 0.0000e+00, 3.0000e+00, 1.0000e+00, 1.5124e+04, 0.0000e+00, 2.0000e+01, 0.0000e+00, 1.0000e+00, 0.0000e+00, 2.1000e+01, 3.3300e+02, 3.0500e+02, 0.0000e+00, 6.0000e+00, 4.0000e+01, 0.0000e+00, 9.0000e+00, 1.9000e+01, 2.8380e+03, 0.0000e+00, 1.7000e+01, 5.0000e+00, 0.0000e+00, 5.0000e+00, 1.4000e+01, 1.8000e+01, 5.0000e+00, 3.9000e+01, 1.5000e+01, 1.0000e+01, 0.0000e+00, 5.0000e+00, 4.0000e+00, 3.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+01])}), ('_RETURN', {'name': '_RETURN', 'type': 'return', 'value': None})]) '''
さて、このようになっていることが分かれば、trace.nodes["o"]["value"]
としてやることで、先程作ったモデルの "o"
と名付けた確率変数の実現値を覗くことができます。
実際にこのように値を取りに行くことをデータ分析者がする必要はありませんが、この設計によって、Pyroはサンプリングされた値やモデルの構造を把握できるというわけです。
条件付き分布の作成 condition
さて、今仮に、
def model(x_data, y_data): w = pyro.sample("weight", pyro.distributions.Normal(0, 10)) b = pyro.sample("bias", pyro.distributions.Normal(0, 10)) with pyro.plate("obs", N): y = pyro.sample("y", pyro.distributions.Normal(w * x_data + b, 0.5), obs=y_data) return y
というモデルを書いてみたとしましょう。これは
$$ \begin{align} w & \sim {\rm Normal} (0, 10) \\ b & \sim {\rm Normal} (0, 10) \\ \end{align} $$
という事前分布のもと、回帰モデル $y = wx + b +\epsilon$ を
$$ \epsilon \sim {\rm Normal} (0, 0.5) $$
あるいは
$$ y \sim {\rm Normal} (y \mid wx+b, 0.5) $$
とモデル化したことになります。これは $x$ が与えられたときの同時分布を
$$ p(y, w, b \mid x) = {\rm Normal} (y \mid wx+b, 0.5) {\rm Normal} (0, 10){\rm Normal} (0, 10) $$
と表したのと同じことです(ちなみに予測したいものが、 $x$ を与えられたときの $y$ である場合には、同時分布の設計の段階で最初から $x$ をgivenな形でモデル化することはままある話である)。さて、仮に $w, b$ を特定の値に確定させた、条件付き分布 $p(y \mid x, w _ c , b _ c)$ が欲しい場合にはどうすればいいのでしょうか。改めてそのような分布を Pyroで書き直す必要はありません。すでに作成した model()
関数からこれを生成することができます。その際に pyro.condition
関数を使います(pyro.poutine.condition
のエイリアス指定がされている)。
y_given_w_b = pyro.condition(model, data={"weight": 4, "bias": 1})
などとすれば良いのです。ここで返ってきた y_given_w_b
は自分で作ったモデル関数 model()
と同じ引数、同じ戻り値を持つ関数になっています。
今、このように得られた y_given_w_b
を
x_index = torch.linspace(-3, 3, N) with torch.no_grad(): y_given_x = y_given_w_b(x_index, None) plt.plot(x_index, y_given_x, "o")
としてみることで確かに $w = 4, b = 1$ のときの $y \sim {\rm Normal}(wx + b, 0.5)$ を再現することができました。
まとめと変分推論の例
さてここまでに登場した巻数を使って変分推論をしてみます。
pyro.clear_param_store() N = 100 w_true = torch.tensor(4.0) b_true = torch.tensor(1.0) x_data = torch.randn([N]) y_data = w_true * x_data + b_true + 0.5*torch.randn([N]) plt.plot(x_data, y_data, "o")
を手元のデータとします。 都合よく、線形モデルを仮定してみます。
def model(x_data, y_data): w = pyro.sample("weight", pyro.distributions.Normal(0, 10)) b = pyro.sample("bias", pyro.distributions.Normal(0, 10)) with pyro.plate("obs", N): y = pyro.sample("y", pyro.distributions.Normal(w * x_data + b, 0.5), obs=y_data) return y
先程は説明を割愛しましたが、引数にx_data, y_data
があるのは、モデルの対数確率の計算で必要となるからです。また、対数確率の計算では「尤度モデル」に相当する部分で、得られている観測実現値が何であるのかを指定するために、y = pyro.sample("y", pyro.distributions.Normal(w * x_data + b, 0.5), obs=y_data)
と obs
引数があることに注意しましょう(この辺はまだちゃんと説明していないが、そういうもんであるということで)。
変分推論をする場合は、変分近似分布を用意する必要がありますが、ここではPyroで準備されている近似分布を借用します(Pyroは確率変数がどれかを把握できるので、どの変数に対して近似分布を用意すべきかも自動で認識できるというわけです)。
guide = pyro.infer.autoguide.guides.AutoDelta(model)
ここでは変分近似として、w, b
にそれぞれデルタ分布を仮定します。デルタ分布はある一点にのみ確率(測度)を持つ分布であり、全くばらつきを持たない分布です。デルタ分布で変分推論をした場合には、ELBO最大化で得られるデルタ分布の確率(測度)が値を持つ点は、同時分布が最大となる点であり、それはすなわちMAP推定の解と一致する点が得られるということになります(言い換えれば、MAP推定は、パラメータにばらつきを一切考慮しないような変分推論の一部として、区別せずに使える。もちろんそれはフレームワークを使うという点での統一しているだけであり、そもそもMAP推定をしようと思ったときには数式上で変分推論を書き下す意味は特にない)。
変分パラメータも自動で設定されグローバルに配置されます。さて、ELBO最大化の変分推論を実施する際には、同時分布のモデルと変分分布が決まってしまえば、あとは一本道であるため、Pyroにはそれを実施するクラスが準備されています。
from pyro.infer import SVI, JitTrace_ELBO from pyro.optim import Adam adam_params = {"lr": 0.001, "betas": (0.95, 0.999)} optimizer = Adam(adam_params) svi = SVI(model, guide, optimizer, loss=JitTrace_ELBO()) n_steps = 5000 for step in range(n_steps): svi.step(x_data, y_data)
これだけで実施できてしまいます。素晴らしい。(ちなみに変分推論をPyTorchで自分で書くときは下記の記事のようにかなり長くなることを知っておくと、ありがたみがわかります)
さて、推論を終えたあとの変分パラメータを下記のように見てみましょう。
for name in pyro.get_param_store(): print("{}: {}".format(name, pyro.param(name))) # AutoDelta.weight: 4.024663925170898 # AutoDelta.bias: 0.9949770569801331
こうして得られた変分パラメータは、MAP推定の解であったので、ベイズ予測分布は単に条件付き分布
$$ p(y\mid w _ {map}, b _ {map}, x) = {\rm Normal} (y \mid w _ {map} x + b _ {map} , 0.5) $$
となります。条件付き分布といえば、pyro.condition
を使えば良く
w_map = pyro.param("AutoDelta.weight") b_map = pyro.param("AutoDelta.bias") y_given_w_b = pyro.condition(model, data={"weight": w_map, "bias": b_map})
で作れるのでした。こうして得られた y_given_w_b
は自分が作ったモデルと同じ引数、戻り値を持つ関数となっています。
x_index = torch.linspace(-3, 3, N) with torch.no_grad(): y_given_x = y_given_w_b(x_index, None)
で実際に確率的なサンプリングを実施できるというわけです。
もちろん、ちゃんと変分推論をしたりする場合にも、書き換えるのは guide
を作る部分だけです。ベイズ予測分布を構成する場合には、条件付き分布の作成ではなく、期待値計算の近似サンプリングが必要になりますが、実はそのようなクラスも準備されているので心配は無用です。
今回は基本をさらっとなぞるだけで終わりにします。