機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

フィッシャーの線形判別(4)

フィッシャーの線形判別(2)


\displaystyle  J(\mathbf w) = \frac{\mathbf w^T \mathbf{S}_B \mathbf w}{\mathbf w^T \mathbf{S}_W \mathbf w} \tag{1}

を最大にする\mathbf wの解を求めましたが、今回はそれとは別の求め方を考えます。

式(1)において、解を一意に定めるため、\mathbf{w}^T \mathbf{S}_W \mathbf w = 1の制約の元での最大を考えます。*1 ラグランジュの未定乗数法により


L = \mathbf{w}^T \mathbf{S}_B \mathbf w - \lambda (\mathbf{w}^T \mathbf{S}_W \mathbf w - 1) \tag{2}

とし、


\displaystyle \frac{\partial L}{\partial \mathbf w} = 2 \mathbf{S}_B \mathbf w - 2 \lambda \mathbf{S}_W \mathbf w = 0\tag{3}

を解けばよいことになります。式(3)より、


 \mathbf{S}_W^{-1} \mathbf{S}_B \mathbf w = \lambda \mathbf w \tag{4}

となります。式(4)は、 \mathbf{S}_{W}^{-1} \mathbf{S}_Bの固有ベクトルが\mathbf w、固有値が\lambdaであることを意味しています。(参考:固有値と固有ベクトル)式(4)を


 \mathbf{S}_B = \lambda  \mathbf{S}_W \tag{5}

として式(1)に代入すれば、


\displaystyle  J(\mathbf w) = \lambda \tag{6}

となります。よって、J(\mathbf w)の最大値は \mathbf{S}_{W}^{-1} \mathbf{S}_Bの最大固有値であり、その固有ベクトルが求める解となります。

ということで、コードを書いて実験してみました。結果はフィッシャーの線形判別(3)と全く同じになりましたから省略します。

# フィッシャーの線形判別

import matplotlib.pyplot as plt
import numpy as np

# 各クラスのデータ数
N = 300

# 入力次元数
D = 2

# クラス数
K = 2

# ランダムシードを固定
np.random.seed(0)

# 2クラス分のデータを作成
mean1 = np.array([0, 2])
mean2 = np.array([0, -2])
cov = [[1.0, -0.7], [-0.7, 1.0]]
x1 = np.random.multivariate_normal(mean1, cov, N).T
x2 = np.random.multivariate_normal(mean2, cov, N).T

# 各クラスの平均ベクトルを求める
m1 = np.array([np.average(x1[0, :]), np.average(x1[1, :])]).reshape(D, 1)
m2 = np.array([np.average(x2[0, :]), np.average(x2[1, :])]).reshape(D, 1)

# クラス内共分散行列を求める
Sw = np.zeros([D, D])
for i in range(N):
    Sw = np.dot((x1[:, i].reshape(D, 1)-m1), (x1[:, i].reshape(D, 1)-m1).T) + \
        np.dot((x2[:, i].reshape(D, 1)-m2), (x2[:, i].reshape(D, 1)-m2).T) + Sw

# クラス間共分散行列を求める
SB = np.outer((m2 - m1), (m2- m1))

# 固有値と固有ベクトルを求める
lam, v = np.linalg.eig(np.dot(np.linalg.inv(Sw),SB))

# 最大固有に対応する固有ベクトルが求めるw グラフ表示用に大きさ調整
w = v[:, np.argmax(lam)]
w = (w/np.linalg.norm(w)).reshape(D, 1)

plt.scatter(x1[0, :], x1[1, :], color="blue", alpha=0.5)
plt.scatter(x2[0, :], x2[1, :], color="red", alpha=0.5)
# 射影の方向は直線wに対して垂直方向
plt.quiver(0, 0, w[1], -w[0], angles="xy", units="xy", color="black", scale=0.5)
plt.show()

y1 = np.dot(w.T, x1)
y2 = np.dot(w.T, x2)

plt.hist(y1[0], bins=30, color="blue", alpha=0.5)
plt.hist(y2[0], bins=30, color="red", alpha=0.5)
plt.show()

*1:この制約は、求められる\mathbf wの方向には影響がないということ、、、だと思いますが参考書は当たり前のものとして進んでいくので自信がない