機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

重回帰分析を最小二乗法で解く

変数が1つの回帰分析は単回帰と呼ばれ(参考:線形回帰を最小二乗法で解く)、複数の変数を使うものは重回帰分析と呼ばれます。今回は重回帰分析を最小二乗法で解いてみたいと思います。と言っても、重回帰分析でも係数が線形結合されていれば、その解は最小二乗法の解の導出で導いたものと全く同じになります。 最小二乗法の解の導出で求めた方程式


\mathbf{w} = (\mathbf{X}^{T}\mathbf{X})^{-1} \mathbf{X}^{T}\mathbf{t} \tag{1}

 \mathbf{X}は、


  \mathbf{X} 
= \left(
    \begin{array}{cccc}
      x_{0}^{0} & x_{0}^{1} & \ldots & x_{0}^{M-1} \\
      x_{1}^{0} & x_{1}^{1} & \ldots & x_{1}^{M-1} \\
      \vdots & \vdots & \ddots & \vdots \\
      x_{N-1}^{0} & x_{N-1}^{1} & \ldots & x_{N-1}^{M-1}
    \end{array}
  \right)
 \tag{2}

としましたが、これは


  \mathbf{X} 
= \left(
    \begin{array}{cccc}
      \phi_{0}(x_{0}) & \phi_{1}(x_{0}) & \ldots & \phi_{M-1}(x_{0}) \\
      \phi_{0}(x_{1}) & \phi_{1}(x_{1}) & \ldots & \phi_{M-1}(x_{1}) \\
      \vdots & \vdots & \ddots & \vdots \\
      \phi_{0}(x_{N-1}) & \phi_{1}(x_{N-1}) & \ldots & \phi_{M-1}(x_{N-1})
    \end{array}
  \right)
 \tag{3}

と一般化できます。 \phi_i(x)=x^{i}としたときが多項式近似に相当します。重回帰分析の場合は、ベクトル \mathbf xを引数に持つ関数 \phi(\mathbf x)とするだけで良いようです。係数 \mathbf wで微分すれば、引数が何種類あろうが \mathbf{X} \mathbf wに依存しないので、結局解は同じになります。

よって、例えば2変数の重回帰分析を考えたとき、 f(x,y) = w_0 x + w_1 y + w_2 xy + w_3 で近似したいなら、 \phi_0(\mathbf x)=x \phi_1(\mathbf x)=y \phi_2(\mathbf x)=xy \phi_3(\mathbf x)=1として、つまり、


  \mathbf{X} 
= \left(
    \begin{array}{cccc}
      x_{0} & y_{0} & x_{0}y_{0} & 1 \\
      x_{1} & y_{1} & x_{1}y_{1} & 1 \\
      \vdots & \vdots & \vdots & \vdots \\
       x_{N-1} & y_{N-1} & x_{N-1}y_{N-1} & 1 
    \end{array}
  \right)
 \tag{4}

として式(1)を解けばよいことになります。

試しにやってみました。何変数あっても良いのですが、グラフで確認しやすいよう2変数にします。モデルは


z = w_0 x +w_1 y + w_2 xy + w_3 x^2 + w_4 y^2 + w_5 
 \tag{5}

として、以下のような空間上の点を近似してみます。

点だけだと真値がわかりづらいですが、これは以下のような青色の曲面にノイズを加えたものです。

そして近似結果。

青が真値で赤が近似曲面です。ちゃんと近似できてますね!単回帰だと見た目でなんとなく近似後の曲線がイメージできますが、訓練データだけではイメージしづらい曲面が近似できるとなんか感動しちゃいますね。

重回帰分析って、確か交互作用とかややこしい話があるんですよね。深入りすると機械学習から離れていきそうなので、とりあえずこの辺にしておきます。

今回のコードです。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 真値の曲面
def z(x, y):
    return x**2 + y**2 + 5*x*y


vec_z = np.vectorize(z)

# ランダムシードを固定
np.random.seed(0)

# 訓練データ数
N = 10

# 訓練データ
x_t = np.random.uniform(-1, 1, N)
y_t = np.random.uniform(-1, 1, N)
t = (vec_z(x_t, y_t) + np.random.normal(0, 0.2, N)).reshape(N, 1)

# 行列Xを作成
X = np.array([x_t, y_t, x_t*y_t, x_t**2, y_t**2, np.ones(N)]).T

# 係数wを求める
w = np.linalg.solve(np.dot(X.T, X), np.dot(X.T, t))

# 真値
x2, y2 = np.meshgrid(np.linspace(-1, 1, 30), np.linspace(-1, 1, 30))
z2 = vec_z(x2, y2)

# 予測値
z3 = w[0]*x2 + w[1]*y2 + w[2]*x2*y2 + w[3]*x2**2 + w[4]*y2**2 + w[5]

# グラフの表示
fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(x_t, y_t, t.T)
ax.plot_wireframe(x2, y2, z2)
ax.plot_wireframe(x2, y2, z3, color="red")
plt.xlim(-1, 1)
plt.ylim(-1, 1)
plt.show()