今回やってみること
MNISTをKL展開して部分空間を図示で、MNISTの数字の7をKL展開し、その部分空間を見てみました。これらは28x28=784次元の基底となっているはずですから、任意の画像を線形結合で表すことができるはずです。今回はいくつかの画像がこの基底によって表現される様子を確認したいと思います。
MNISTの7をKL展開して得られた固有ベクトル
再掲ですが、7の固有ベクトルと、線形結合のイメージ図です。
7の固有ベクトル
線形結合のイメージ
展開された基底で元の画像を復元する
実際に見ていきます。まずは数字の7。
固有ベクトルの大きいものから順に足していっています。タイトルの"dimension"は何次元までを使って表現しているか、を表します。7をKL展開しただけあって、784次元のうち、最初の数次元程度でほとんど元の形が表現できていることがわかります。ということは、手書きの7を表すだけなら数次元あれば十分ということなんですね。
次に数字の6。
こちらは、元の形が復元されるスピードが7と比べて遅いことがわかります。だいたい50次元くらいでようやく7ではなく6であることが見えてきています。ただ、数字は違えど画像自体は比較的似通っていますので、100次元も使えば元の形が十分表現できています。
最後は数字ではなくイラスト。
数字の7とは遠すぎて、なかなか元の画像が見えてきません。が、やはり784次元全て使えば完全に元の画像を表現できました。当たり前のことを確認をしただけですが、想像していた通りの結果になると嬉しいですね。
さて、これを分類手法に使ったものを部分空間法と呼ぶようです。今回は7の固有ベクトルしか見ていませんが、分類したい画像それぞれの固有ベクトルを計算しておき、どの固有ベクトルを使えば対象の画像がよく表現されるか?を判断指標に分類します。次回、部分空間法の1つであるCLAFIC法という手法でMNISTを識別したとき、どのくらいの精度になるのか見てみたいと思います。
今回のコードです
# 画像を7の固有ベクトルを基底とした空間で表現 import numpy as np import matplotlib.pyplot as plt from matplotlib import cm import matplotlib.animation as ani from PIL import Image def func(frame, img): # 固有値の降順に足していく img += sub[frame]*v[:, frame] plt.title("dimension:{0}".format(frame)) im.set_data(img.reshape(28, 28)) # MNISTの7の固有ベクトルを読み込み。固有値の降順にしてある。 # https://www.iwanttobeacat.com/entry/2018/06/26/224446で求めたもの v = np.load("./data/7.npy") # オリジナル画像 orig = np.array(Image.open("./data/mario.bmp").convert('L')) orig = orig.flatten() # オリジナル画像を表示 fig = plt.figure(figsize=(4, 3)) plt.imshow(orig.reshape(28, 28), vmin = 0, vmax = 255, interpolation="None", cmap=cm.gray) plt.axes().set_aspect('equal') plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) plt.show() # 基底変換 sub = np.dot(v.T, orig) # アニメーションで表示 fig = plt.figure(figsize=(4, 3)) plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) im = plt.imshow(np.zeros([28, 28]), vmin = 0, vmax = 255, interpolation="None", cmap=cm.gray) img = np.zeros(784) anim = ani.FuncAnimation(fig, func, fargs=(img,), interval=1, repeat=True, frames=784) anim.save('anim_mario.gif', writer="imagemagick") plt.show()