畳み込みニューラルネットワークをKeras風に定義するとアーキテクチャを図示してくれるツールを作った

はじめに

論文やスライドで、畳み込みニューラルネットワークのアーキテクチャを良い感じに表示したいときがありますよね?スライドだとオリジナル論文の図の引用でも良いかなという気がしますが、論文の図としては使いたくありません。
ということでKerasのSequentialモデルのような記法でモデルを定義すると、そのアーキテクチャを良い感じに図示してくれるツールを作りました。言ってしまえばテキストを出力しているだけのツールなので依存ライブラリとかもありません。
https://github.com/yu4u/convnet-drawer

ここまで実装するつもりはなかったので綺麗に設計できていませんが、バグ報告や追加機能要望welcomeです!

経緯

元々は、全然違う論文用の図の作成方法を検討していたところ、@gou_koutaki 先生からSVGで出力すればいいじゃんとコメントを頂きました。

Pythonにはsvgwriteというライブラリがあり、SVGが簡単に出力できそうだったのですが、このくらいであれば自分で書いちゃえるかなと色々遊んでいたところ、気づいたら畳み込みニューラルネットワークの図示ツールを作成していました…
最初はただのBOXを描いていたのが、気づけば畳み込み層を作って、プーリング層も実装したら全結合層も欲しくなり、実際のモデルに合わせるためにpaddingも対応していました…何をしているんだろう…

ちなみに、畳み込みニューラルネットワークを図示するような素晴らしいツールとして下記があるのですが、モデルの書き方が直感的ではないのと、個人的には特徴マップをボリュームで表現したかったので利用したことはありませんでした。
https://github.com/gwding/draw_convnet

利用方法

KerasのSequentialモデルの記法のように、畳み込み層やプーリング層を重ねていくだけです。例えばAlexNetは下記のようにモデルを記述します。

from convnet_drawer import Model, Conv2D, MaxPooling2D, Flatten, Dense

model = Model(input_shape=(227, 227, 3))
model.add(Conv2D(96, (11, 11), (4, 4)))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Conv2D(256, (5, 5), padding="same"))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Conv2D(384, (3, 3), padding="same"))
model.add(Conv2D(384, (3, 3), padding="same"))
model.add(Conv2D(256, (3, 3), padding="same"))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(4096))
model.add(Dense(4096))
model.add(Dense(1000))
model.save_fig("example.svg")

上記のスクリプトを実行すると、下記のようなSVGフォーマットの画像が出力されます。ベクタ画像なので拡大しても綺麗です。PowerPoint経由で論文用のepsやpdfファイルにも変換できると思います(未確認)。

image.png

対応レイヤ

一般的なレイヤしか対応していません。Deconvや(Sequentialなので当然ですが)ResNetのskip connectionのようなものはありません。そもそも最近のモデルは深すぎて図示しても何が何やらになりそうですね。

Conv2D

Conv2D(filters, kernel_size, strides=(1, 1), padding="valid")
畳み込み層です。filtersにフィルタ数を、kernel_sizeにフィルタのカーネルサイズ(タプル)を、stridesにストライドのサイズ(タプル)を入力します。padding"valid""same"のみ対応しています(実装上は"same"でなければ"valid"になっちゃいます)。
(例)Conv2D(96, (11, 11), (4, 4)))

MaxPooling2D, AveragePooling2D

MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid")
プーリング層です。pool_sizeにプーリングのカーネルサイズ(タプル)を、stridesにストライドのサイズ(タプル)を、paddingにパディングのタイプを入力します。stridesを入力しない場合、pool_sizeと同じ値がセットされます。
(例)MaxPooling2D((3, 3), strides=(2, 2))

GlobalAveragePooling2D

GlobalAveragePooling2D()
Global average poolingです。特徴マップのサイズを1×1にし、flatten(1次元に)します。この後は全結合層のみが追加できます。

Flatten

Flatten()
特徴マップをflattenし、1次元にします。この後は全結合層のみが追加できます。

Dense

Dense(units)
全結合層です。unitsに出力次元数を入力します。
(例)Dense(4096)

実行例

LeNet

image.png

AlexNet
image.png

ZFNet
image.png

VGG16
image.png