はじめに
これはもはやただの備忘録です。 Pyroに無い機能をこっちで動かす確認をしたかったというだけでございます。
import tensorflow as tf import tensorflow_probability as tfp from tensorflow_probability import distributions as tfd from tensorflow_probability import bijectors as tfb import numpy as np import matplotlib.pyplot as plt plt.style.use("seaborn")
を前提とします。
データ
def toy_data(): sigma1 = 1.5 mu1 = 10.0 dist1 = tfd.Normal(loc=mu1, scale=sigma1) sigma2 = 3.0 mu2 = 2.0 dist2 = tfd.Normal(loc=mu2, scale=sigma2) return tf.concat([dist1.sample(100), dist2.sample(200)], axis=0)
やっつけでデータを作成しております。個々のデータは単に正規分布で、それが1:2の割合で混ざっています。 見かけ上は負担率1/3の混合正規分布のようなヒストグラムになります。
モデル
ということで、正規分布2つが混ざっている混合モデルを適当な事前分布を仮定して作ってみます。 データを見ればおおよそ各々の平均や分散が見積もれるので、その付近にばらつきを持った事前分布を仮定しました。
標準偏差は正の値であるべきですので、ガウス分布から生成した log_sigma
を tf.math.exp
に食わしてパラメータとして与えます。
負担率は 0から1の値ですので、logit
を作ってから tf.math.sigmoid
に食わせます。再パラメータ化はMCMCの結果に影響するので注意が必要です。
あまり非線形性の強い変換(曲率の大きな変換)はその付近でサンプリングが敏感に変わってしまうことになるでしょう。
再パラメータ化せずに、分散に対しては正の値しかサンプリングしない逆ガンマ分布であったり、負担率であれば0から1しか値を取らないベータ分布を使う方法もあります。
root = tfd.JointDistributionCoroutine.Root def model(): mu1 = yield root(tfd.Normal(loc=3.0, scale=3)) log_sigma1 = yield root(tfd.Normal(loc=0, scale=2)) sigma1 = tf.math.exp(log_sigma1) mu2 = yield root(tfd.Normal(loc=10, scale=3)) log_sigma2 = yield root(tfd.Normal(loc=0, scale=2)) sigma2 = tf.math.exp(log_sigma2) logit = yield root(tfd.Normal(loc=0, scale=1.0)) prob = tf.math.sigmoid(logit) x = yield tfd.Sample( tfd.MixtureSameFamily( mixture_distribution=tfd.Categorical(probs=[prob, 1-prob]), components_distribution=tfd.Normal( loc=[mu1, mu2], scale=[sigma1, sigma2], ) ), sample_shape=300 ) joint = tfd.JointDistributionCoroutine(model)
上記のモデルの作り方の基本は下記で
遷移核
サンプリングしたい分布は同時分布に対して、観測データ sample
を与えた状態にすれば得られます。
サンプリングを実施するための遷移核にはNUTSを使い、レプリカ交換モンテカルロ法のクラスでラッピングします。
この時、make_kernel_fn
にはtarget_log_prob_fn
と書かれているサンプリングしたい分布と、ランダムシードを引数にしておきます。
なぜかランダムシードの引数がないと、後で怒られます。
レプリカ交換モンテカルロのクラスでラップするときに、seed=tf.random.set_seed()
を利用しなければなりません。intを直接与えず、こうしておかないとTF2.0が怒ります。謎。
しかもintを与えた時に出てくるエラーメッセージは、存在しない架空の(過去にはあったのだろうが)メソッドを使えと表示されるため、解決方法を見つけるまでハマりました。
sample = toy_data() def unnormalized_log_prob(mu1, log_sigma1, mu2, log_sigma2, logit): return tf.reduce_mean(joint.log_prob([mu1, log_sigma1, mu2, log_sigma2, logit, sample])) def make_kernel_fn(target_log_prob_fn, seed): return tfp.mcmc.NoUTurnSampler( target_log_prob_fn=target_log_prob_fn, step_size=0.05, seed=seed ) remc = tfp.mcmc.ReplicaExchangeMC( target_log_prob_fn=unnormalized_log_prob, inverse_temperatures=[tf.constant(0.2), tf.constant(0.2), tf.constant(0.2), tf.constant(0.2), tf.constant(0.2)], make_kernel_fn=make_kernel_fn, seed=tf.random.set_seed(1) )
あとは回すだけ
@tf.function() def run_chain(): with tf.device("/gpu:0"): init_state = list(joint.sample()[:-1]) chains_states, kernels_results = tfp.mcmc.sample_chain( num_results=1000, num_burnin_steps=300, current_state=init_state, kernel=remc, parallel_iterations=50 ) return chains_states, kernels_results chain_states, kernel_results = run_chain()
言っておきますがめちゃくちゃ遅いです。いや、一次元なんだからGPU使う必要はなかったのだろうか…?よくわからないけど、どっちにしてもアクビが出るほど遅かったです(多峰が相手になったときもNUTSを初期値変えてサンプリングを複数回繰り返して事後分布を結合したほうがマシなんじゃないか?くらい遅い)。
ひとまず備忘録なのでここまで。