読者です 読者をやめる 読者になる 読者になる

人工知能に関する断創録

人工知能、認知科学、心理学、ロボティクス、生物学、ゲームAIなどに興味を持っています。このブログでは人工知能のさまざまな分野について調査したことをまとめています。最近は、機械学習、複雑系(カオス、力学系)、Deep Learningなど。



Chainerによる畳み込みニューラルネットワークの実装

Deep Learning 機械学習 Chainer

Chainerによる多層パーセプトロンの実装(2015/10/5)のつづき。今回はChainerで畳み込みニューラルネットワーク(CNN:Convolutional Neural Network)を実装した。Theanoによる畳み込みニューラルネットワークの実装 (1)(2015/6/26)で書いたのと同じ構造をChainerで試しただけ。タスクは前回と同じくMNIST。

f:id:aidiary:20150626203849p:plain

今回は、MNISTデータの取得や訓練/テストの分割にscikit-learnの関数を使ってみた。

Chainerで畳み込みをするためには、訓練データの画像セットを(ミニバッチサイズ、チャンネル数、高さ、幅)の4次元テンソルに変換する必要があるここに書いてある)。今回はチャンネル数が1なので単純にreshapeで変形できる。

3チャンネルのカラー画像だとnumpyのtranspose()で4次元テンソルに変換できるみたい。transpose()は転置行列作るときに使うけどこのnumpyサンプルの3例目によるとndarrayの次元を入れ替えるときにも使えるようだ。あとで物体認識をやるときに確認しよう。

訓練時の誤差とテスト精度を描いてみると下のようになった。エポックが進むにつれて誤差が減り、学習が進んでいることがわかる。テスト精度は多少がたがたするが徐々に向上し、最大で99.3%くらいになる。今回はEarly-Stoppingのような高度な収束判定は使わず、単純に20エポック回しただけなので手を抜いている。GTX760で20エポックの学習に984秒かかった。

f:id:aidiary:20151007213443p:plain f:id:aidiary:20151007213449p:plain

学習したモデルはcPickleでファイルにダンプできる。このフォーラムの記事によると学習したモデルをファイルにダンプするときはmodel.to_cpu()でGPUからCPUに戻した方がよいとのこと。こうしておけばGPUがないマシンでも学習済みモデルを読み込める。

畳み込みニューラルネットは、学習対象の重みがフィルタに当たるので画像として描画できる。試しに学習したモデルの重みを可視化してみよう。下のようなコードで描ける。

#coding: utf-8
import cPickle
import matplotlib.pyplot as plt
model = cPickle.load(open("model.pkl", "rb"))

# 1つめのConvolution層の重みを可視化
print model.conv1.W.shape

n1, n2, h, w = model.conv1.W.shape
fig = plt.figure()
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for i in range(n1):
    ax = fig.add_subplot(2, 10, i + 1, xticks=[], yticks=[])
    ax.imshow(model.conv1.W[i, 0], cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()


# 2つめのConvolution層の重みを可視化
print model.conv2.W.shape
n1, n2, h, w = model.conv2.W.shape
fig = plt.figure()
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
pos = 0
for i in range(10):
    for j in range(10):
        ax = fig.add_subplot(10, 10, pos + 1, xticks=[], yticks=[])
        ax.imshow(model.conv2.W[i, j], cmap=plt.cm.gray_r, interpolation='nearest')
        pos += 1
plt.show()

f:id:aidiary:20151007220137p:plain f:id:aidiary:20151007220141p:plain

ちょっと解釈できない。ガボールフィルタみたいなのができるはずなのだけれど、ランダムだった初期状態からあまり変わらない気もする。もう少しエポック回せばよかったのかな?でも精度は十分上がったしなぁ。もっと別の例でも確認してみよう。

参考