先回までにメトロポリス・ヘイスティングズ法によるMCMCの実装を行ってきました。(参考:マルコフ連鎖モンテカルロ法 - メトロポリス・ヘイスティングズ法(1))これによって様々な分布からサンプリングできるようになったので、今回はMAP推定をMCMCで解いてみたいと思います。MAP推定については、線形回帰をMAP推定で解くに書いています。求めたい係数の事前分布を導入し、ベイズの定理を使って、訓練データが与えられた時の事後分布を求め、その最大値を解とする、というものでした。MCMCを使えば、解析的に解けない問題でも事後分布さえわかっていればそこからサンプリングしてしまえば良いということですね。
さて、いくつかの前提条件を線形回帰をMAP推定で解くと同様のものとすれば、事後分布は
と表されます。この式からMCMCを使って直接サンプリングし、最大値を求めてみます。訓練データは線形回帰でやってきたものと同じく、正弦波に正規分布の誤差が加わっているものにし、3次の多項式
で回帰することとします。
早速結果です。の事後分布は以下のようになりました。
そしてこの係数を使った近似結果がこちら。
ちゃんと近似できていることが確認できました!
今回のコードです。
# MAP推定をMCMC法を使って解く import numpy as np import matplotlib.pyplot as plt # y = w0*x^0+....wM*x^M を、引数xの配列数分求める def y(w, x, M): X = np.empty((M + 1, x.size)) for i in range(M + 1): X[i, :] = x ** i return np.dot(w.T, X) # 提案分布 def q(x): mean = np.zeros(M+1) cov = 0.5*np.eye(M+1) z = np.random.multivariate_normal(mean, cov).T return x + z # サンプルする関数 def p(w): return np.exp(-np.dot((t - np.dot(X, w)).T, (t - np.dot(X, w)))/(2*(v**2)) - 0.5*np.dot(np.dot(w.T, np.linalg.inv(s0)), w)) # MCMCでMAP推定値を求める def MAP_MCMC(x0): # サンプル候補 w0 = q(x0) if p(w0.reshape(M+1, 1))/p(x0.reshape(M+1, 1)) > np.random.uniform(0, 1): return w0 else: return x0 # ランダムシードを固定 np.random.seed(0) # 多項式の最大べき乗数(x^0+...+x^M) M = 3 # 訓練データ数 N = 10 # 訓練データの列ベクトル x = np.linspace(0, 1, N).reshape(N, 1) # 訓練データtの列ベクトル t = np.sin(2*np.pi*x.T) + np.random.normal(0, 0.2, N) t = t.reshape(N, 1) # 行列Xを作成 X = np.empty((N, M+1)) for i in range(M+1): X[:, i] = x.reshape(1, N) ** i # 事前分布の共分散行列 v0 = 1000 s0 = v0 * np.eye(M+1, M+1) # 真値に加わっている正規分布ノイズの分散 v = 0.2 # サンプル数 N = 100000 # バーンイン B = 10000 # MCMCの初期値 x0 = np.zeros(M+1) samples = np.empty([N, M+1]) # MCMC for i in range(N): x0 = MAP_MCMC(x0) # print(i) samples[i, :] = x0 # print(np.average(samples[:,0])) # print(np.average(samples[:,1])) # print(np.average(samples[:,2])) # print(np.average(samples[:,3])) # # plt.hist(samples[B:,0],bins=100,normed=True) # plt.show() # plt.hist(samples[B:,1],bins=100,normed=True) # plt.show() # plt.hist(samples[B:,2],bins=100,normed=True) # plt.show() # plt.hist(samples[B:,3],bins=100,normed=True) # plt.show() w0 = np.average(samples[B:,0]) w1 = np.average(samples[B:,1]) w2 = np.average(samples[B:,2]) w3 = np.average(samples[B:,3]) w = np.array([w0,w1,w2,w3]).reshape(M+1,1) # 求めた係数wを元に、新たな入力x2に対する予測値yを求める x2 = np.linspace(0, 1, 100) y1 = y(w, x2, M) # 結果の表示 plt.xlim(0.0, 1.0) plt.ylim(-1.5, 1.5) plt.scatter(x, t) plt.plot(x2, y1.T) plt.show()