機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

入力を非線形変換したパーセプトロン

パーセプトロンで単純な識別の実験をし、そこでは訓練データ\mathbf x, tが与えられたとき、


  y(\mathbf x) = f(\mathbf{w}^{T} \mathbf{x}) \tag{1}

で表されるモデルを考えました。線形に識別できる問題しか扱えず、応用が効かなさそうですが、入力\mathbf{x}を非線形変換をした


  y(\mathbf x) = f(\mathbf{w}^{T} \phi(\mathbf{x})) \tag{2}

を考えると柔軟な識別が可能になります。

実験してみました。 f:id:opabinia2:20180708113430p:plain

ここでは、入力を\mathbf{x}=(x_0,x_1)^{T}としたとき、\phi(\mathbf{x}) = (x_0^{2},x_1^{2} ,x_0 x_1 ,x_0, x_1 , 1)^{T}のような変換としています。曲線で分離されるような分布もきちんと識別できてますね。ただやはり、非線形変換した\phi(\mathbf{x})が、\mathbf{w}の線形で識別できる場合でないと解けないはずです。

今回のコードです。

# 入力を非線形変換したパーセプトロン 
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm


# クラス分け
def create_data(x0, x1):
    if (x0-0.5)**2 + 0.5*x1**2 - 1 > 0:
        return -1.0
    else:
        return 1.0


# 識別結果を返す
def f(x0, x1):
    if np.dot(w.T, phi(x0, x1)) > 0:
        return -1.0
    else:
        return 1.0


# 非線形変換
def phi(x0, x1):
    return np.array([x0**2, x1**2, x0*x1, x0, x1, 1])


# 全データ数
N = 300

# 入力次元数
D = 2

# クラス数
K = 2

# 学習率
Eta = 0.1

# ランダムシードを固定
np.random.seed(0)

# 2クラス分のデータを作成
x = np.random.uniform(-2, 2, [D, N])
t = np.empty([N])
t = np.vectorize(create_data)(x[0, :], x[1, :])


# 重みベクトルwを初期化
# Φ(x) = x0^2 + x1^2 + x0x1 + x0 + x0 + 1 とする(求める係数は6個)
w = np.random.uniform(-1, 1, 6)

# データを非線形変換する
phi_x = np.array([phi(x[0, i], x[1, i]) for i in range(N)])

# 全データに対して更新処理を行い、誤りがなくなるまで繰り返す
while True:
    # 更新処理の順をランダムで決める
    index = np.random.permutation(np.arange(0, N, 1))

    break_flag = True
    # 更新処理
    for i in range(N):
        if t[index[i]] * np.dot(w.T, phi_x[index[i], :]) < 0:
            w = w + Eta * t[index[i]] * phi_x[index[i], :]
            break_flag = False

    if break_flag:
        break

plt.xlim(-2, 2)
plt.ylim(-2, 2)

# グラフの色分け
a, b = np.meshgrid(np.linspace(-2, 2, 1000), np.linspace(-2, 2, 1000))
vec_f = np.vectorize(f)
plt.contourf(a, b, vec_f(a, b), alpha=0.2, cmap=cm.coolwarm)

plt.scatter(x[:,np.where(t==1)][0], x[:,np.where(t==1)][1], color="blue", alpha=0.5)
plt.scatter(x[:,np.where(t==-1)][0], x[:,np.where(t==-1)][1], color="red", alpha=0.5)
plt.show()