機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

棄却サンプリング

確率分布p(x)からサンプリングしたいが、正規分布や一様分布のようにライブラリが提供されていない場合にどうするか?逆関数法もその1つですが、棄却サンプリングはより直感的な手法です。

棄却サンプリングの原理

確率分布は\displaystyle \int p(x)dx = 1を満たすべきですが、正規化定数が不明で、


\displaystyle p(x) = \frac{1}{Z_p}\tilde{p}(x) \tag{1}

\tilde{p}(x)のみ既知である場合を考え、この\tilde{p}(x)からサンプリングしたいとします。もちろん正規化定数も既知でも問題ないのですが、おそらく不明の場合のほうが一般的だろうと思われます。

ここで全てのxに対して


kq(x) \ge \tilde{p}(x) \tag{2}

を満たす定数kおよびq(x)を定めます。このq(x)は与えられたものではなく、サンプリング時に設定するもので提案分布と呼びます。提案分布はサンプリングできるものを選択します。


視覚的には上図のように、\tilde{p}(x)を覆うようにkおよびq(x)を定めます。具体的なサンプリングの手順はとてもシンプルです。

棄却サンプリングの手順
  1. q(x)からzをサンプリングする
  2. \displaystyle \frac{\tilde{p}(z)}{kq(z)}の確率でzを採択し、採択できなければ再びq(x)から新たなサンプルをとり、採択できるまで繰り返す。


手順2によって採択できない、つまり棄却されるサンプル数を少なくするため、提案分布は\tilde{p}(x)に近いことが望ましい。また、手順2で確率として解釈できるために、つまり[0,1]の範囲であるために式(2)の条件が必要なんですね。この手順によって、\tilde{p}(x)がどんな形状であろうとサンプリング可能です。提案分布によってサンプリングされやすい領域があったとしても採択確率が小さければ、つまりそこで\tilde{p}(x)が小さければ、棄却されやすくなるだけです。

非常にシンプルでわかりやすい手順です。ああ、こんなことでいいんだって感じ。しかしながら万能ではなく、特に高次元では適切な提案分布を定めることが難しく採択確率が低くなってしまうなどの問題があるようです。

サンプリング実験

さてコードを書いて実験してみました。 p(x) =\sin^{2} 0.5 \pi xからサンプリングしてみます。区間は[-1,1]とします。わかりやすく提案分布は一様分布としました。つまりq(x)=0.5です。p(x)の最大値は1になるので、定数k=2とします。以下のグラフが結果です。

p(x)からうまくサンプリングできていますね。

今回のコードです。なお今回はp(x)が対象区間で積分すると1になるようにしましたが、そうでない場合は、このコードのままグラフを描くとサンプリングが失敗しているように見えます。最初はコードにミスがあると思って悩んでしまいました。下のグラフは p(x) =\sin^{2} 0.3 \pi xとした場合です。p(x)とサンプルのヒストグラムを1つのグラフに描くと形状一致していないように見えてしまいます。

# 棄却サンプリングの実験

import numpy as np
import matplotlib.pyplot as plt


def p(x):
    return np.sin(0.5*np.pi*x)**2


# 棄却サンプリング
def sampling():
    k = 2
    # 採択するまでループ
    while True:
        # 提案分布q(z)からサンプリング
        z = np.random.uniform(-1, 1)
        # [0,kq(z)]の一様分布からサンプリング
        u = k*np.random.uniform(0, 0.5)
        # 棄却するか判定
        if p(z) > u:
            return z


x = np.linspace(-1, 1, 500)
y = p(x)

# サンプル数
N = 100000

samples = np.array([sampling() for i in range(N)])

plt.xlim(-1, 1)
plt.hist(samples, bins=50,normed=True,alpha=0.5)
plt.plot(x, y)
plt.show()