機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

多クラスロジスティック回帰をニュートン法で解く

多クラスロジスティック回帰のヘッセ行列を求めることができましたので(参考:多クラスロジスティック回帰における交差エントロピー誤差のヘッセ行列)、今回は実際にニュートン法を使って解いてみたいと思います。

まずはおさらい。多クラスロジスティック回帰の勾配は


 \left(
    \begin{array}{c}
      \displaystyle  \sum_{n=1}^{N} (y_{n1}-t_{n1})\mathbf{x}_n \\
      \vdots  \\
       \displaystyle \sum_{n=1}^{N} (y_{nK}-t_{nK})\mathbf{x}_n
    \end{array}
  \right) \tag{1}

で、ヘッセ行列は


\displaystyle \mathbf{H}_{ij} =  \mathbf{X}^{T} \mathbf{R}_{ij} \mathbf{X} \tag{2}

として、


\displaystyle \mathbf{H} = \left(
    \begin{array}{cccc}
      \mathbf{H}_{11} & \mathbf{H}_{12} & \ldots & \mathbf{H}_{1K} \\
      \mathbf{H}_{21} & \mathbf{H}_{22} & \ldots & \mathbf{H}_{2K} \\
      \vdots & \vdots & \ddots & \vdots \\
      \mathbf{H}_{K1} & \mathbf{H}_{K2} & \ldots & \mathbf{H}_{KK}
    \end{array}
  \right) \tag{3}

でした。

多変数のニュートン法の更新式は


\mathbf{x} = \mathbf{x}^{\prime} - \bar{\mathbf{H}}^{-1} \nabla \bar{f} \tag{4}

で、この\nabla fが式(1)、\mathbf{H}が式(3)です。(参考:ニュートン法(多変数の場合))バーがついているのは、 \mathbf{x}^{\prime}における値という意味です。

式(4)の更新式を使って解いてみた結果です。 f:id:opabinia2:20180908003126p:plain ちょっと微妙。ニュートン法のイテレーションを増やすとソフトマックス関数がオーバーフローしてしまい、その前で止めたので十分な識別になっていません。

思いつきでヘッセ行列の逆行列を、擬似逆行列にしてみたらオーバーフローすることなくきちんと識別できました。 f:id:opabinia2:20180908003034p:plain なんで擬似逆行列だとオーバーフローしないのかな。ヘッセ行列の逆行列が正しく計算できなかった、というのが根本の問題なのかな?うーん。

# 多クラスロジスティック回帰
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm


# 判別結果
def f(x0, x1):
    return np.argmax(np.dot(w.T, np.array([x0, x1, 1.0])))


# 交差エントロピー誤差
def Ew(y, T):
    ew = 0
    for n in range(N):
        for k in range(K):
            ew -= T[n, k] * np.log(y[n, k])
    return ew


# 教師データの属するクラスを返す
def teaching(x, y):
    if y > x + 0.5:
        return 0
    elif y > x - 0.5:
        return 1
    return 2


# ソフトマックス関数
def softmax(a, k):
    return np.exp(a[k])/np.sum(np.exp(a))


# 重み付け行列Rを計算
def R_(y, i, j):
    I = np.eye(N)
    R_ij = np.empty(N)
    for k in range(N):
        R_ij[k] = y[k, i] * (I[i, j] - y[k, j])
    return np.diag(R_ij)


# ヘッセ行列を計算
def H_(y):
    H = np.empty([M * K, M * K])
    for i in range(K):
        for j in range(K):
            # 重み付け行列R
            R = R_(y, i, j)
            H[M * j:M * j + M, M * i:M * i + M] = np.dot(np.dot(X.T, R), X)
    return H


# ランダムシードを固定
np.random.seed(0)

# 訓練データ数
N = 90

# 入力次元数(ダミー入力を含めて)
M = 3

# クラス数
K = 3

# 教師データ作成
x = np.random.uniform(0, 1, [N, M-1])
t = np.vectorize(teaching)(x[:, 0], x[:, 1])

# 行列X(ダミー入力を加える)
X = np.hstack([x,  np.ones([N, 1])])

# 1 of K符号
T = np.eye(K)[t]

# パラメータw初期化
w = np.zeros([M, K])

# ベクトルy
y = np.zeros([N, K])
a = np.dot(X, w)
for i in range(N):
    y[i, :] = np.array([softmax(a[i, :], k) for k in range(K)])

# 勾配ベクトルを計算
dw = np.zeros([M*K])
for k in range(K):
    dw[k * M:k * M + M] = np.dot(X.T, y[:, k] - T[:, k])

# 最急降下法
# w = w.T.flatten()
# for it in range(100):
#  # 交差エントロピー誤差
#  print(Ew(y, T))
#  w -= 0.1*dw
#  a = np.dot(X, w.reshape([K, M]).T)
#  for i in range(N):
#      y[i, :] = np.array([softmax(a[i, :], k) for k in range(K)])
#  for k in range(K):
#      dw[k * M:k * M + M] = np.dot(X.T, y[:, k] - T[:, k])
# w = w.reshape([M, K]).T


# ヘッセ行列
H = H_(y)

# ニュートン法
w = w.T.flatten()
for ite in range(10):
    print(Ew(y, T))
    # w -= np.dot(np.linalg.inv(H), dw)
    w -= np.dot(np.linalg.pinv(H), dw)
    # yを再計算
    a = np.dot(X, w.reshape([K, M]).T)
    for i in range(N):
        y[i, :] = np.array([softmax(a[i, :], k) for k in range(K)])
    # 勾配を再計算
    for k in range(K):
        dw[k * M:k * M + M] = np.dot(X.T, y[:, k] - T[:, k])
    # ヘッセ行列を再計算
    H = H_(y)
w = w.reshape([M, K]).T

# グラフ表示用の判別結果
a, b = np.meshgrid(np.linspace(0, 1, 1000), np.linspace(0, 1, 1000))
vec_f = np.vectorize(f)
c = vec_f(a, b)

plt.xlim(0, 1)
plt.ylim(0, 1)
plt.scatter(x[t == 0, 0], x[t == 0, 1], color="blue", alpha=0.5)
plt.scatter(x[t == 1, 0], x[t == 1, 1], color="green", alpha=0.5)
plt.scatter(x[t == 2, 0], x[t == 2, 1], color="red", alpha=0.5)
plt.contourf(a, b, c, alpha=0.2, cmap=cm.coolwarm)
plt.show()