機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

線形識別の概要

今回から線形識別を勉強していこうと思っています。

何らかの入力\mathbf xに対し、識別関数を通すことで、\mathbf xが属するクラスは何なのかを識別するというもの。線形回帰よりはちょっと機械学習っぽい感じですね。識別関数は、例えば


y(\mathbf x) = \mathbf w ^{T} \mathbf x + w_0 \tag{1}

のように表されます。入力が2次元で2クラスに分類する場合を考えれば、入力\mathbf xを識別関数に通した結果、y(\mathbf x) \ge 0ならばクラスC_1そうでなければC_2と識別します。式(1)は、定数の入力1を使えば


y(\mathbf x) = \mathbf w ^{T} \mathbf x \tag{2}

と書けます。線形回帰でも多項式を式(2)のような形で表していたのと同じですね。(参考:こちらの式(1),(3)) 式(1)や(2)の表現だと、直感的に線分を思い浮かべてしまうのですが、\mathbf xが2次元なら y(\mathbf x) は平面を表しています。下図のような感じです。黒い平面が識別関数で、正なら赤色のクラス、負なら青色のクラスに分類しています。

f:id:opabinia2:20180425205553p:plain

同じように3クラスに分類する問題を考えるとうまくいきません。例えば2つの識別関数 y_0(\mathbf x),  y_1(\mathbf x) を用いると、各々の出力の正負の組み合わせにより、4つのクラスが生じてしまうためです。正負で判断するのではなく、各クラスに対応する識別関数 y_i(\mathbf x)を用意し、入力を各識別関数に通した結果、出力が最大のクラスに属する、と判断することで対応可能です。適当な識別関数を作って、3クラス、5クラスをこの方法で分類するとどういう領域となるのか実験してみました。

f:id:opabinia2:20180425210756p:plain f:id:opabinia2:20180425210759p:plain

なんか線形識別を扱う記事で見たことのあるような雰囲気の図ができました!この方法で分類できる領域は必ず凸領域(領域内の任意の2点を結んだ直線は、常に領域内である)になります。けっこう証明は簡単です。ある領域kの任意の2点\mathbf x_A,\mathbf x_Bを考えると、この2点を結ぶ直線上にある任意の点\hat{\mathbf x}は、0 \le \lambda \le 1である\lambdaを用いて

 \hat{\mathbf x} = \lambda \mathbf x_A + (1- \lambda) \mathbf x_B \tag{3}

と書けます。*1 そして識別関数(式(2))の線形性より

 y_k(\hat{\mathbf x} )= \lambda y_k(\mathbf x_A) + (1- \lambda) y_k(\mathbf x_B )\tag{4}

です。\mathbf x_A,\mathbf x_Bは領域k内の点でしたから、その他の全ての識別関数に対して y_k(\mathbf x_A) \ge y_i(\mathbf x_A),y_k(\mathbf x_B) \ge y_i(\mathbf x_B)です。よって、式(4)より y_k(\hat{\mathbf x} ) \ge y_i(\hat{\mathbf x} )であり、領域内の2点を内分する点は同じ領域に属することとなります。つまり凸領域です。

凸領域にしか分類できないとすると、あまり実用的でなさそうですが、基底変換することで分類面は柔軟にできるようです。次回は最小二乗法を使って実際に識別の問題を解いてみたいと思います。次回:線形識別 最小二乗法の解の導出

線形識別勉強のまとめ:線形識別の目次

識別領域の確認に作ったコードです。

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm


def y(x1, x2):
    y1 = 5*x1 - x2 + 3
    y2 = -x1 + x2
    y3 = x1 - 5*x2
    y4 = 6*x1 - 2*x2 - 2
    y5 = -5*x1 - 6*x2 + 1
    return np.argmax([y1, y2, y3, y4, y5])

x1 = np.linspace(-5, 5, 1000)
x2 = np.linspace(-5, 5, 1000)
X1, X2 = np.meshgrid(x1, x2)

vec_y = np.vectorize(y)

Y = vec_y(X1, X2)
plt.contourf(X1, X2, Y, cmap=cm.coolwarm)
plt.show()

*1:高校で学ぶ2点間を内分する点の関係より