先回までのまとめ
フィッシャーの線形判別(1):フィッシャーの線形判別の概要
フィッシャーの線形判別(2):フィッシャーの線形判別の解を求める
フィッシャーの線形判別(3):2クラスの次元圧縮の実験
フィッシャーの線形判別(4):フィッシャーの線形判別の解の別の求め方
フィッシャーの線形判別を2クラスの場合について考えてきましたが、多クラスにも対応できるようです。が、手持ちの参考書の解説はかなりさっぱりしていて、あまりきちんと理解できていない。再勉強が必要。
多クラスへの一般化
入力次元を、次元に削減することを考え、クラス数はとします。ここで入力次元はクラス数より大きいとします。すると、クラス内共分散、クラス間共分散を以下のように一般化できるそうです。
そして求める解は
を最大化するで、の固有値の大きいものから順に、対応する固有ベクトルを個並べたものが解になるとのこと。
ということで、あまり理解していないので、中途半端に説明するのはやめて潔く結論だけ書いてしまいました。難しい数学背景があるようです。
理屈がわかっていなくてもコードは書けます。 4次元入力の3クラスを、2次元にしてみました。 グラフだけ見ると、きちんと分離度を保ちつつ次元を減らせていそうですが、、、。4次元入力がもはや図で確認できないので、うまくできたのかどうか判断しづらい。高次元のものを扱う場合、確からしさの確認ってどうやるのがいいんでしょうか。
今回のコードです。もうちょっとスマートに書きたいところ。
# フィッシャーの線形判別 多クラス import matplotlib.pyplot as plt import numpy as np # 各クラスのデータ数 N = 100 # クラス数 K = 3 # 入力次元数 D = 4 # 削減後の次元 D2 = 2 # ランダムシードを固定 np.random.seed(0) # 4クラス分のデータを作成 mean1 = np.array([-3, 3, -3, 0]) mean2 = np.array([3, 3, 3, -3]) mean3 = np.array([3, 3, 3, 3]) cov = np.eye(D, D) x1 = np.random.multivariate_normal(mean1, cov, N).T x2 = np.random.multivariate_normal(mean2, cov, N).T x3 = np.random.multivariate_normal(mean3, cov, N).T xx = np.array([x1, x2, x3]) # 各クラスの平均ベクトルを求める m1 = np.array([np.average(x1[i, :]) for i in range(D)]).reshape(D, 1) m2 = np.array([np.average(x2[i, :]) for i in range(D)]).reshape(D, 1) m3 = np.array([np.average(x3[i, :]) for i in range(D)]).reshape(D, 1) mx = np.array([m1, m2, m3]) # 全クラスの平均ベクトル m = (m1 + m2 + m3)/K # クラス間共分散行列を求める SB = np.zeros([D, D]) for i in range(K): SB += N*np.outer((mx[i] - m), (mx[i] - m)) # クラス内共分散行列を求める SW = np.zeros([D, D]) for k in range(K): for i in range(N): SW += np.outer((xx[k][:, i] - mx[k][:, 0]), (xx[k][:, i] - mx[k][:, 0])) # Wを求める(SW^(-1)SBの固有ベクトルの中から、固有値の大きいものから次元数分使用) lam, v = np.linalg.eig(np.dot(np.linalg.inv(SW), SB)) v = v[:, np.argsort(lam)[::-1]] W = v[:, 0:D2] # 結果確認 y1 = np.dot(W.T, x1) y2 = np.dot(W.T, x2) y3 = np.dot(W.T, x3) plt.scatter(y1[0, :], y1[1, :], color="blue", alpha=0.5) plt.scatter(y2[0, :], y2[1, :], color="red", alpha=0.5) plt.scatter(y3[0, :], y3[1, :], color="green", alpha=0.5) plt.show()