機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

線形回帰をMAP推定で解く

線形回帰をMAP推定で解きます。MAPとは、Maximum a posteriori:最大事後確率の略です。モデルのパラメータwがある分布に従って生起するとして事前分布を設定します。そして得られた訓練データとベイズの定理を用いて、設定した事前分布からwの事後分布を求めます。そしてこの事後分布から、最も生起する確率が高いwを求める解として採用します。

図で書くとこのようなイメージです。設定したwの青色の事前分布と訓練データによって、緑色のような事後分布が得られたとします。そしてこの分布において最も確率の高い値w^{\prime}を解とします。最小二乗法において、過学習が起きるときはパラメータの値が大きい傾向にありました。そういう値をとる確率は低い、という事前分布を与えておけば過学習抑制の効果が期待できます。

MAP推定で解を求める手順

  1. 係数wの事前分布を設定する。
  2. 与えられた訓練データにより、ベイズの定理を使って係数wの事後分布が定まる
  3. 事後分布の最大値(MAP:最大事後確率)をwの値として採用する
  4. 新たな入力に対する予測値が求まる

パラメータの事前分布を p(\mathbf w)とすれば、訓練データ \mathbf tが与えられたときの事後分布 p(\mathbf w|\mathbf t)ベイズの定理


\displaystyle P(X|Y) = \frac{P(Y|X)P(X)}{P(Y)} \tag{1}

より


\displaystyle p(\mathbf w|\mathbf t) = \frac{p(\mathbf t|\mathbf w)p(\mathbf w)}{p(\mathbf t)} \tag{2}

と書けます。訓練データには真値に平均0、分散 \sigma^{2}の正規分布 N(0,\sigma^{2})に従う誤差が加わっていると仮定すれば、尤度 p(\mathbf t|\mathbf w)は、


\displaystyle p(\mathbf t|\mathbf w) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left( \frac{ -\| \mathbf t - \mathbf X \mathbf w\|^2 }{ 2 \sigma ^2}\right) \tag{3}

です。(ここで \mathbf X \mathbf w最小二乗法の解の導出の式(4)を参照ください)なぜ尤度がこう書けるか?わかっている人にはすごくクドい話になりますが、僕は悩んでしまったので説明します。まず尤度という解釈はさておき、 p(\mathbf t|\mathbf w)の意味を考えます。これは \mathbf wが与えられたときの \mathbf tの確率分布です。 \mathbf tは真の値 \mathbf yから N(0,\sigma^{2})の誤差\mathbf{\epsilon}が加わっていると仮定してます。つまり

 
\mathbf t = \mathbf y + \mathbf{\epsilon} = \mathbf X \mathbf w + \mathbf{\epsilon} \tag{4}

です。 \mathbf Xは訓練データのことですから観測済みの値です。そして \mathbf w p(\mathbf t|\mathbf w)の条件に入っています。よって \mathbf X \mathbf wは揺らがない値、つまり定数になります。そして \epsilon N(0,\sigma^{2})の誤差でした。ですから、 \mathbf tは正規分布 N(\mathbf X \mathbf w,\sigma^{2})に従うはずです。よって式(3)のように書けます。で、これは \mathbf tの確率密度関数ではなく \mathbf wの関数とみなしたときに尤度として解釈できます。(観測済みの値がなぜ確率密度関数の式で書けるのか?で少し悩みました。確率密度関数ではなく尤度として解釈すべきなんですね、、、。*1

さて、事前分布 p(\mathbf w)を平均 \mathbf m_{0}、共分散 \mathbf S_0の多変量正規分布と設定すれば、


\displaystyle p(\mathbf{w})= \frac{1}{\sqrt{(2\pi)^{n} \det \mathbf S_0}}\exp\left\{ -\frac{1}{2}(\mathbf{w} - \mathbf{\mathbf m_0})^T \mathbf S_{0}^{-1}(\mathbf{w} - \mathbf{\mathbf m_0}) \right\} \tag{5}

と書けます。*2 よって式(3)(5)を式(2)に代入すれば、事後分布 p(\mathbf w|\mathbf t)は、


\displaystyle p(\mathbf w|\mathbf t) \propto \exp\left\{ \frac{ -\| \mathbf t - \mathbf X \mathbf w\|^2 }{ 2 \sigma ^2} -\frac{1}{2}(\mathbf{w} - \mathbf{\mathbf m_0})^T \mathbf S_{0}^{-1}(\mathbf{w} - \mathbf{\mathbf m_0}) \right\} \tag{6}

となります。式(2)における分母や、式(3)(5)における定数倍項は解に影響しないため省略しています。

さて、式(6)を解いていくのですが、長くなるので導出はこちらで書きました。 www.iwanttobeacat.com

今回は結果だけ。


\displaystyle \mathbf S_N^{-1} = \frac{\mathbf X^T \mathbf X}{\sigma^{2}} + \mathbf S_0^{-1} \tag{7}

\displaystyle \mathbf m_N = \mathbf S_N \left( \frac{ \mathbf X^T \mathbf t}{\sigma^{2}}  + \mathbf S_0^{-1} \mathbf m_0 \right) \tag{8}

式(8)の \mathbf m_Nが事後分布を最大化する解になります。

ここで \mathbf wの事前分布を、平均 \mathbf m_0=0、共分散 \mathbf S_0 = \sigma_0^{2}\mathbf I、つまり各パラメータ間に相関はなく、全て等しい分散であると仮定してみます。式(8)に式(7)をかけて、 \mathbf S_Nの項を消せば、


\displaystyle \left( \frac{\mathbf X^T \mathbf X}{\sigma^{2}} + \mathbf S_0^{-1} \right) \mathbf m_N = \mathbf S_N^{-1}\mathbf S_N \left( \frac{ \mathbf X^T \mathbf t}{\sigma^{2}}  + \mathbf S_0^{-1} \mathbf m_0 \right) \tag{9}


\displaystyle \left( \frac{ \mathbf X^T \mathbf X}{\sigma^{2}} + \frac{ \mathbf I}{\sigma_0^{2}}\right)\mathbf m_N = \frac{ \mathbf X^T \mathbf t}{\sigma^{2}}  \tag{10}

となります。両辺に \sigma^{2} をかけて、 \displaystyle \lambda = \frac{\sigma^{2}}{\sigma_0^{2}}とおけば、


\begin{eqnarray*}

\mathbf X^T \mathbf X \mathbf m_N + \lambda \mathbf m_N &=& \mathbf X^T \mathbf t \tag{11} \\
(\mathbf X^T \mathbf X + \lambda \mathbf I) \mathbf m_N &=& \mathbf X^T \mathbf t \tag{12} \\
\mathbf{m_N} &=& (\mathbf{X}^{T}\mathbf{X} + \lambda \mathbf{I})^{-1} \mathbf{X}^{T}\mathbf{t} \tag{13}

\end{eqnarray*}

となり、式(13)は正則化最小二乗法で求めた式と全く同じになります。

最小二乗法では \mathbf wの取りうる全ての範囲で最適解を探してしまうため過学習が発生します。これに対して正則化項を加えることで、 \mathbf wの解に制約条件を課すことができ過学習が防ぐことができました。今回 \mathbf wの事前分布を導入することは、その意味からも明らかですが、制約条件を課していることとなり、分散が等方的であるなど特別な条件下においてはMAP推定は正則化最小二乗法と等価になります。

、、、ということで一応コードも書いて実験してみましたが、正則化最小二乗法と同じなので今回は省略します。事前分布の分散を非常に大きな値に設定すると過学習がやっぱり起きるよねっていう確認をしたくらい。