機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

線形識別を最小二乗法で解く

線形識別の最小二乗法よる解

線形識別の最小二乗法による解は、線形識別 最小二乗法の解の導出で導出しました。 今回はそれを用いて実際にいくつか識別させてみたいと思います。

識別結果

まずは単純な直線y=xで分離される識別を解いてみました。点が教師データで、塗りつぶされている領域が識別結果です。 ちゃんと識別できていますね。左側だと、ちょっと分離面がズレていますが、教師データ数を増やしてみたらきっちり分離できました。

しかし線形識別を最小二乗法で解くのは問題があるようです。なぜなら、最小二乗法とは何であったかを思い出してみると、真値からの誤差が正規分布に従っているときに使えるものでした。*1 線形識別における真値とは、正解クラスのみが1である1ofK符号で、2クラスの場合なら(1,0)または(0,1)です。*2 この符号から正規分布の誤差が加わったものが識別関数の結果と考えることはできませんので、最小二乗法は適していないということです。教師データの中に大きく外れている値が含まれていると、以下のように簡単に識別がズレてしまします。(外れ値はグラフ外のため表示されていません)外れている値といっても、y=xで適切に分類される正しい教師データですが、二乗誤差が大きくなってしまい、その影響を受けてしまうんですね。

外れている値であれば、事前に選別して取り除くことができると思いますが、教師データに偏りがあっても以下のように全然識別できませんでした。 これもうまくいかないのは同じ理屈です。

さらに、クラス数を増やしてみると論外なほど識別がうまくいかない場合があります。

ということで、線形識別に対して最小二乗法は使えない、ということが確認できました。

線形識別勉強のまとめ:線形識別の目次

今回のコードです。3クラスでは全くうまくいかない、と本文で書きましたが、今回のコードのような教師データの与え方では、データ数を増やしたらそれっぽくなりました。いずれにしても使えないことには違いありません。

# 最小二乗法による線形識別
import matplotlib.pyplot as plt
import numpy as np


# 教師データの属するクラスを返す
def teaching(x, y):
    if y > x + 0.3:
        return 0
    if y > x - 0.3:
        return 1
    return 2


# 判別したクラス返す
def f(x0, x1):
    return np.argmax(np.dot(w.T, np.array([1.0, x0, x1])))


# ランダムシードを固定
np.random.seed(0)

# 訓練データ数
N = 40

# 入力次元数(ダミー入力を含めて)
M = 3

# クラス数
K = 3

# 教師データ作成
T = np.zeros([N, K])
t = np.empty(N)
xt = np.random.uniform(0, 1, N)
yt = np.random.uniform(0, 1, N)

# 外れ値
# xt[0] = 8
# yt[0] = 0.2

for i in range(N):
    t[i] = teaching(xt[i], yt[i])
    T[i, int(t[i])] = 1


# Xを作る
X = np.hstack([np.ones([N, 1]), xt.reshape([N, 1]), yt.reshape([N, 1])])

# 係数wを求める
w = np.linalg.solve(np.dot(X.T, X), np.dot(X.T, T))

# グラフ表示用の判別結果
a, b = np.meshgrid(np.linspace(0, 1, 1000), np.linspace(0, 1, 1000))
vec_f = np.vectorize(f)
c = vec_f(a, b)

plt.xlim(0, 1)
plt.ylim(0, 1)
plt.scatter(xt[t == 0], yt[t == 0], color="blue", alpha=0.5)
plt.scatter(xt[t == 1], yt[t == 1], color="green", alpha=0.5)
plt.scatter(xt[t == 2], yt[t == 2], color="red", alpha=0.5)
plt.contourf(a, b, c, alpha=0.2)
plt.show()