機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

MNISTをKL展開して部分空間を図示

MNISTとは

MNISTとは、http://mldata.org/repository/data/viewslug/mnist-original/で公開されている下図のような手書き数字データセットです。機械学習の分類手法のベンチマーク(or 勉強用?)によく使われているようです。

MNISTをKL展開する

今回は、このMNISTの画像データをKL展開し、どのような部分空間が得られるのか見てみたいと思います。これまでは2~3次元のデータで実験してきましたが、画像に変わっても次元が増えるだけで計算は全く変わりません。計算するには画像をベクトルで表現する必要がありますが、以下のような値を持つ2x2ピクセルの画像があったとすれば、単純にそれを列ベクトルにまとめればよいだけです。



MNISTの画像データは28x28のサイズですので、784次元のデータになります。KL展開はすでにKL展開 分散最大基準で実験済みですから、データをMNISTに差し替えるだけです。今回は数字の7をKL展開してみました。7はデータセットの中に7293個含まれていました。得られた部分空間を図示してみます。もちろん、KL展開で得られるベクトルが0~255の整数をとっているわけではなく、正規化して図示しています。


グラフタイトルの数字は、何番目に大きい固有値に対応するものか?を表しています。固有値の大きい軸は数字の7の特徴をよく捉えており、固有値が小さくなるにしたがって特徴が薄くなっていき、最後はほとんど何も表さない軸になっています。ビジュアルで確認できると面白いですね!

さて、これによって得られた784個のベクトルは、784次元空間の基底となっています。7の数字をよく近似してはいますが、任意の28x28の画像はこの線形結合によって表されます。以下のようなイメージ。

数字の7であれば、上位の基底を使うだけで近似でき、それ以外の画像だと下位の基底まで使わないと近似できないはずです。次回この辺を実験してみたいと思います。→KL展開で得られた基底で画像を表現する

今回のコードです。MNISTの.h5ファイルの使い方がわからなくて悩んだ。結局HDFViewというソフトでデータ構造を確認してようやく使えた。

# MNISTをKL展開し、部分空間を図示
import h5py
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm


mnist = h5py.File("./mnist/mnist-original.h5", "r")
data = np.array(mnist["data"].get("data").value)
label = np.array(mnist["data"].get("label").value)

# 削減次元数
D2 = 30

# 7のデータだけ取る
x = data[:, np.where(label==7)[0]]

# データ次元数
D = 784

# データ数
N = x.shape[1]

# 平均ベクトルを求める
m = np.array([np.average(x[i, :]) for i in range(D)])

# 共分散行列を求める
s = np.zeros([D, D])
for i in range(N):
    s += np.outer(x[:, i] - m, x[:, i] - m)
s = s/N


# 固有値と固有ベクトルを求める
lam, v = np.linalg.eigh(s)

# 固有値の降順に固有ベクトルを並べ替える
v = v[:, np.argsort(lam)[::-1]]
# D2次元の部分空間の基底が得られる
w = v[:, 0:D2]

# np.save("./data/7", v)

for i in range(5):
    plt.subplot(3, 5, i+1).set_aspect('equal')
    plt.title(i+1)
    plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False)
    plt.imshow(v[:,i].reshape(28, 28), interpolation="None", cmap=cm.gray)

for i in range(5):
    plt.subplot(3, 5, i+6).set_aspect('equal')
    plt.title(i+100)
    plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False)
    plt.imshow(v[:,i+99].reshape(28, 28), interpolation="None", cmap=cm.gray)

for i in range(5):
    plt.subplot(3, 5, i+11).set_aspect('equal')
    plt.title(i+780)
    plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False)
    plt.imshow(v[:,i+779].reshape(28, 28), interpolation="None", cmap=cm.gray)

plt.show()