はじめに
論文やスライドで、畳み込みニューラルネットワークのアーキテクチャを良い感じに表示したいときがありますよね?スライドだとオリジナル論文の図の引用でも良いかなという気がしますが、論文の図としては使いたくありません。
ということでKerasのSequentialモデルのような記法でモデルを定義すると、そのアーキテクチャを良い感じに図示してくれるツールを作りました。言ってしまえばテキストを出力しているだけのツールなので依存ライブラリとかもありません。
https://github.com/yu4u/convnet-drawer
ここまで実装するつもりはなかったので綺麗に設計できていませんが、バグ報告や追加機能要望welcomeです!
経緯
元々は、全然違う論文用の図の作成方法を検討していたところ、@gou_koutaki 先生からSVGで出力すればいいじゃんとコメントを頂きました。
Pythonにはsvgwriteというライブラリがあり、SVGが簡単に出力できそうだったのですが、このくらいであれば自分で書いちゃえるかなと色々遊んでいたところ、気づいたら畳み込みニューラルネットワークの図示ツールを作成していました…
最初はただのBOXを描いていたのが、気づけば畳み込み層を作って、プーリング層も実装したら全結合層も欲しくなり、実際のモデルに合わせるためにpaddingも対応していました…何をしているんだろう…
ちなみに、畳み込みニューラルネットワークを図示するような素晴らしいツールとして下記があるのですが、モデルの書き方が直感的ではないのと、個人的には特徴マップをボリュームで表現したかったので利用したことはありませんでした。
https://github.com/gwding/draw_convnet
利用方法
KerasのSequentialモデルの記法のように、畳み込み層やプーリング層を重ねていくだけです。例えばAlexNetは下記のようにモデルを記述します。
from convnet_drawer import Model, Conv2D, MaxPooling2D, Flatten, Dense
model = Model(input_shape=(227, 227, 3))
model.add(Conv2D(96, (11, 11), (4, 4)))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Conv2D(256, (5, 5), padding="same"))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Conv2D(384, (3, 3), padding="same"))
model.add(Conv2D(384, (3, 3), padding="same"))
model.add(Conv2D(256, (3, 3), padding="same"))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(4096))
model.add(Dense(4096))
model.add(Dense(1000))
model.save_fig("example.svg")
上記のスクリプトを実行すると、下記のようなSVGフォーマットの画像が出力されます。ベクタ画像なので拡大しても綺麗です。PowerPoint経由で論文用のepsやpdfファイルにも変換できると思います(未確認)。
対応レイヤ
一般的なレイヤしか対応していません。Deconvや(Sequentialなので当然ですが)ResNetのskip connectionのようなものはありません。そもそも最近のモデルは深すぎて図示しても何が何やらになりそうですね。
Conv2D
Conv2D(filters, kernel_size, strides=(1, 1), padding="valid")
畳み込み層です。filters
にフィルタ数を、kernel_size
にフィルタのカーネルサイズ(タプル)を、strides
にストライドのサイズ(タプル)を入力します。padding
は"valid"
か"same"
のみ対応しています(実装上は"same"
でなければ"valid"
になっちゃいます)。
(例)Conv2D(96, (11, 11), (4, 4)))
MaxPooling2D, AveragePooling2D
MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid")
プーリング層です。pool_size
にプーリングのカーネルサイズ(タプル)を、strides
にストライドのサイズ(タプル)を、padding
にパディングのタイプを入力します。strides
を入力しない場合、pool_size
と同じ値がセットされます。
(例)MaxPooling2D((3, 3), strides=(2, 2))
GlobalAveragePooling2D
GlobalAveragePooling2D()
Global average poolingです。特徴マップのサイズを1×1にし、flatten(1次元に)します。この後は全結合層のみが追加できます。
Flatten
Flatten()
特徴マップをflattenし、1次元にします。この後は全結合層のみが追加できます。
Dense
Dense(units)
全結合層です。units
に出力次元数を入力します。
(例)Dense(4096)
実行例
LeNet