機械学習に詳しくなりたいブログ

機械学習や数学について勉強した内容を中心に書きます。100%趣味です。記事は数学的に厳密でなかったり誤りを含んでいるかもしれません。ご指摘頂ければ幸いです。

KL展開で得られた基底で画像を表現する

今回やってみること

MNISTをKL展開して部分空間を図示で、MNISTの数字の7をKL展開し、その部分空間を見てみました。これらは28x28=784次元の基底となっているはずですから、任意の画像を線形結合で表すことができるはずです。今回はいくつかの画像がこの基底によって表現される様子を確認したいと思います。

MNISTの7をKL展開して得られた固有ベクトル

再掲ですが、7の固有ベクトルと、線形結合のイメージ図です。

7の固有ベクトル


線形結合のイメージ

展開された基底で元の画像を復元する

実際に見ていきます。まずは数字の7。

固有ベクトルの大きいものから順に足していっています。タイトルの"dimension"は何次元までを使って表現しているか、を表します。7をKL展開しただけあって、784次元のうち、最初の数次元程度でほとんど元の形が表現できていることがわかります。ということは、手書きの7を表すだけなら数次元あれば十分ということなんですね。


次に数字の6。

こちらは、元の形が復元されるスピードが7と比べて遅いことがわかります。だいたい50次元くらいでようやく7ではなく6であることが見えてきています。ただ、数字は違えど画像自体は比較的似通っていますので、100次元も使えば元の形が十分表現できています。


最後は数字ではなくイラスト。

数字の7とは遠すぎて、なかなか元の画像が見えてきません。が、やはり784次元全て使えば完全に元の画像を表現できました。当たり前のことを確認をしただけですが、想像していた通りの結果になると嬉しいですね。

さて、これを分類手法に使ったものを部分空間法と呼ぶようです。今回は7の固有ベクトルしか見ていませんが、分類したい画像それぞれの固有ベクトルを計算しておき、どの固有ベクトルを使えば対象の画像がよく表現されるか?を判断指標に分類します。次回、部分空間法の1つであるCLAFIC法という手法でMNISTを識別したとき、どのくらいの精度になるのか見てみたいと思います。

今回のコードです

# 画像を7の固有ベクトルを基底とした空間で表現

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.animation as ani
from PIL import Image


def func(frame, img):
    # 固有値の降順に足していく
    img += sub[frame]*v[:, frame]
    plt.title("dimension:{0}".format(frame))
    im.set_data(img.reshape(28, 28))


# MNISTの7の固有ベクトルを読み込み。固有値の降順にしてある。
# https://www.iwanttobeacat.com/entry/2018/06/26/224446で求めたもの
v = np.load("./data/7.npy")

# オリジナル画像
orig = np.array(Image.open("./data/mario.bmp").convert('L'))
orig = orig.flatten()

# オリジナル画像を表示
fig = plt.figure(figsize=(4, 3))
plt.imshow(orig.reshape(28, 28), vmin = 0, vmax = 255, interpolation="None", cmap=cm.gray)
plt.axes().set_aspect('equal')
plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False)
plt.show()

# 基底変換
sub = np.dot(v.T, orig)

# アニメーションで表示
fig = plt.figure(figsize=(4, 3))
plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False)
im = plt.imshow(np.zeros([28, 28]), vmin = 0, vmax = 255, interpolation="None", cmap=cm.gray)
img = np.zeros(784)
anim = ani.FuncAnimation(fig, func, fargs=(img,), interval=1, repeat=True, frames=784)
anim.save('anim_mario.gif', writer="imagemagick")
plt.show()