機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

ニューラルネットワークで回帰を解く

先回、ニューラルネットワークモデルの勾配を求める誤差逆伝播法の手順を確認しましたので、実際にこれを使って回帰を解いてみたいと思います。今回は最急降下法を使いました。モデルは下図のようなものとしました。活性化関数h\tanhとし、出力の活性化関数はなし(恒等関数)としました。なぜこうしたかと聞かれると、参考書に例として載っていたからとしか答えられませんが、、、。\tanhはシグモイド関数にしても同じようなことだと思いますが、それぞれ何か使い勝手が違うんでしょうか。(シグモイド関数:確率的生成モデル(1)

f:id:opabinia2:20190119144247p:plain

早速結果です。青色の点が訓練データで、赤い線がニューラルネットワークの学習によって得られた出力です。概ね学習データを近似できています。不連続な部分などは少し誤差が目立ちますが、繰り返し回数をさらに10倍に増やしたところで、結果はそんなに変わりませんでした。最急降下法ではなくもっと効率のよいアルゴリズムを使えば結果は改善されるようです。 f:id:opabinia2:20190119145539p:plain

線形回帰を最小二乗法で解くなどで扱っていた、ノイズが乗ったようなデータでももちろん近似曲線を求めることができます。 f:id:opabinia2:20190120220742p:plain 右のグラフは隠れ層を増やしてみた場合。やっぱり過学習が起きてしまうんですね。

今回のコードです。wosugiさんのコメントを参考に、np.dotではなく@演算子を使ってみました。こっちのほうがすっきりしていいですね。ありがとうございました。

# ニューラルネットワーク 回帰
import numpy as np
import matplotlib.pyplot as plt


# 対象の関数
def f(x):
    if x < 0:
        return -1
    else:
        return  1
    # return x**2
    # return np.sin(np.pi*x)
    # if x < 0:
    #  return -x
    # else:
    #  return  x

# 順伝播
def forward(x):
    # 隠れ層
    a1 = w1 @ x
    # 活性化関数
    h1 = np.tanh(a1)
    # ダミー入力を加える
    h1 = np.vstack([h1, np.ones(N)])
    # 出力と隠れ層の値を返す
    return w2 @ h1, a1


# 誤差逆伝播
def backprop(y, a1):
    # 出力層の誤差δ
    delta2 = y - t
    # 隠れ層の誤差δ
    delta1 = (1 - np.tanh(a1)**2) * (w2.T @ (y-t))
    return delta1, delta2


# 入力次元数(ダミーを含まない)
D = 1

# 隠れ層の数
M = 3

# 出力層
K = 1

# 重みパラメータ(ダミー入力分+1している)
w1 = np.random.uniform(-1, 1, [M, D+1])
w2 = np.random.uniform(-1, 1, [K, M+1])

# データ数
N = 50

# 学習係数
ALPHA = 0.01

# 訓練データ
x = np.linspace(-1, 1, N)
t = np.vectorize(f)(x)
x = np.vstack([x, np.ones(N)])

for i in range(30000):
    # 順伝播
    y, a1 = forward(x)
    # ユニットの誤差δを求める
    delta1, delta2 = backprop(y, np.vstack([a1, np.ones(N)]))
    # 偏微分係数を求める。
    dw2 = delta2 @ np.vstack([np.tanh(a1), np.ones(N)]).T
    dw1 = delta1[0:M, :]@x.T

    # 最急降下法による更新
    w2 -= ALPHA*dw2
    w1 -= ALPHA*dw1

plt.plot(np.linspace(-1, 1, N),y[0],c="r")
plt.scatter(np.linspace(-1, 1, N),t,alpha=0.5)
plt.show()