先回、ニューラルネットワークで回帰を解くで回帰を解きましたので、今回は3クラスの分類問題をニューラルネットワークで解いてみます。出力の活性化関数をソフトマックスにするだけで、実装は前回と同じです。隠れ層は前回の3つのままだとうまく分類できなかったので6つに増やしました。今回も最急降下法なので、もしかしたらもっと効率の良い最適化手法なら隠れ層3つでも何とかなったのかもしれません。ニューラルネットワークの図は描くのが大変なので省略します、、、。
ということで早速結果です。
ちゃんと分類できていますね。
今回のコードです。相変わらずダサい書き方で恥ずかしいのですが。
# ニューラルネットワーク 分類 import numpy as np import matplotlib.pyplot as plt from matplotlib import cm # ソフトマックス関数 def softmax(a): # 引数のaは、データを列に並べたもの。よって、列サイズはデータ数、行数はクラス数 ret = np.empty([K, a.shape[1]]) for i in range(a.shape[1]): ret[:, i] =np.array([np.exp(a[k,i])/np.sum(np.exp(a[:,i])) for k in range(K)]) return ret # 順伝播 def forward(x): # 隠れ層 a1 = w1 @ x # 活性化関数 h1 = np.tanh(a1) # ダミー入力を加える h1 = np.vstack([h1, np.ones(x.shape[1])]) # ソフトマックスの出力とa1を返す return softmax(w2 @ h1), a1 # 誤差逆伝播 def backprop(y, a1): # 出力層の誤差 delta2 = y - t # 隠れ層の誤差 delta1 = (1 - np.tanh(a1)**2) * (w2.T @ (y-t)) return delta1, delta2 # ランダムシードを固定 np.random.seed(0) # 入力次元数(ダミーを含まない) D = 2 # 隠れ層の数 M = 6 # 出力層 K = 3 # 重みパラメータ(ダミー入力分+1している) w1 = np.random.uniform(-1, 1, [M, D+1]) w2 = np.random.uniform(-1, 1, [K, M+1]) # データ数 N0 = 100 N = N0*K # 学習係数 ALPHA = 0.01 # 訓練データ x1 = np.linspace(0, 0.5, N0) x2 = np.linspace(0.25, 0.75, N0) x3 = np.linspace(0.5, 1.0, N0) y1 = np.sin(2*np.pi*x1) + np.random.normal(0, 0.2, N0) y2 = -1*np.sin(2*np.pi*(x2-0.25)) + np.random.normal(0, 0.2, N0) y3 = np.sin(2*np.pi*(x3-0.5)) + np.random.normal(0, 0.2, N0) x = np.vstack([np.hstack([x1, x2, x3]), np.hstack([y1, y2, y3])]) t = np.empty([K, N]) # ダミー入力を加える x = np.vstack([x, np.ones(x.shape[1])]) for i in range(N): t[:, i] = np.eye(K)[int(i//N0)] for i in range(30000): # 順伝播 y, a1 = forward(x) # ユニットの重みを求める delta1, delta2 = backprop(y, np.vstack([a1, np.ones(N)])) # 微分係数を求める dw2 = delta2 @ np.vstack([np.tanh(a1), np.ones(N)]).T dw1 = delta1[0:M, :] @ x.T # 更新 w2 -= ALPHA*dw2 w1 -= ALPHA*dw1 a_x0, a_x1 = np.meshgrid(np.linspace(0, 1, 300), np.linspace(-1.5, 1.5, 300)) a_x0 = a_x0.flatten() a_x1 = a_x1.flatten() ans, dummy = forward(np.vstack([a_x0, a_x1, np.ones(300*300)])) c = np.argmax(ans,axis=0) c = c.reshape([300,300]) a_x0 = a_x0.reshape([300,300]) a_x1 = a_x1.reshape([300,300]) plt.xlim(0,1) plt.ylim(-1.5,1.5) plt.contourf(a_x0, a_x1, c, alpha=0.2, cmap=cm.coolwarm) plt.scatter(x1,y1) plt.scatter(x2,y2) plt.scatter(x3,y3) plt.show()