皆さんこんにちは
お元気ですか。最近、Chainer便利でびっくりしたような頃合いです。
頻繁に更新することで有名なChainerですが、久々にupgradeすると以前よりも
シンプルなタスクについて、簡単に学習ができます。
Trainer
Chainer version 1.11.0よりTrainerと呼ばれる機能が実装されています。
以前まで学習用バッチ処理を自前で書くようなことが
必要でしたが、これを使うことによってバッチ処理を書く必要がなくなります。
実際の機能としてはある処理をhockしたり、グラフを出力したり
レポートを表示したりと学習中に確認したいグラフは沢山あります。
それらのグラフを可視化したいといったことは往々にしてあります。
Trainerの基本的な使い方
Trainerを使うと、Progress Barやlogを自動的に吐き出すことができます。
通常のモードでは、Trainerを基本的に使うことができます。
Extensionsを使うことにより、Trainerを使えます。
殆どExample通りですが、以下が最低限のコードとなります。
# coding:utf-8 from __future__ import absolute_import from __future__ import unicode_literals import chainer import chainer.datasets from chainer import training from chainer.training import extensions import chainer.links as L import chainer.functions as F class MLP(chainer.Chain): def __init__(self, n_units, n_out): super(MLP, self).__init__( l1=L.Linear(None, n_units), l2=L.Linear(None, n_units), l3=L.Linear(None, n_out), ) def __call__(self, x): h1 = F.relu(self.l1(x)) h2 = F.relu(self.l2(h1)) return self.l3(h2) train, test = chainer.datasets.get_mnist() train_iter = chainer.iterators.SerialIterator(train, 32) test_iter = chainer.iterators.SerialIterator(test, 32, repeat=False, shuffle=False) model = L.Classifier(MLP(784, 10)) optimizer = chainer.optimizers.SGD() optimizer.setup(model) updater = training.StandardUpdater(train_iter, optimizer, device=-1) trainer = training.Trainer(updater, (10, 'epoch'), out="result") trainer.extend(extensions.Evaluator(test_iter, model, device=10)) trainer.extend(extensions.dump_graph('main/loss')) trainer.extend(extensions.snapshot(), trigger=(10, 'epoch')) trainer.extend(extensions.LogReport()) trainer.extend(extensions.PrintReport( ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy'])) trainer.extend(extensions.ProgressBar()) trainer.run()
Trainerにはupdate方法を宣言します。
Trainer#extendを使うことで、一定の条件の元で起動します。
Extension | 概要 |
---|---|
Evaluator | 一定の期間で評価する。(validation) |
dump_graph | グラフを表示する。 |
snapshot | 一定の間隔(ユーザ指定)でモデルを保存する |
LogReport | ログとして出力する。 |
PrintReport | print文を使って現状をprintする。(以下に例あり) |
ProgressBar | Progress Barを表示する。 |
上記の場合の出力例は次のとおりです。
epoch main/loss validation/main/loss main/accuracy validation/main/accuracy 1 0.624464 0.306581 0.850083 0.913538 2 0.282575 0.240019 0.919283 0.932608 total [##########........................................] 21.87% this epoch [#########.........................................] 18.67% 4100 iter, 2 epoch / 10 epochs 61.825 iters/sec. Estimated time to finish: 0:03:56.958209.
また、結果として、出力されるresult配下のディレクトリは次のとおりです。
-rw-r--r-- 1 Tereka staff 2250 10 24 23:33 cg.dot -rw------- 1 Tereka staff 2590 10 24 23:39 log -rw------- 1 Tereka staff 4775680 10 24 23:39 snapshot_iter_18750
DatasetMixinを使った拡張
ImageNetのサンプルにありますが、Real Time Augmentationを行うことができます。
これを応用すると様々な用途で使うことができて非常に便利です。
例えば、ファイルを順次読み出したい場合に
ファイルをデータとして渡しておき、それを処理するタイミングで順次読み出すことができます。
また、データを加工することも自由にできるため、自由にデータに変換を加えることができます。
chainer.dataset.DatasetMixinを使って以下のような拡張が可能です。
以下の拡張はシンプルです。chainer.dataset.DatasetMixin#get_exampleを使うと実現できます。
このメソッド内部にファイルを読み込む処理を作ると、
実際にファイルを読み込むことが可能となります。
例は次の通りです。
import skimage.io class DatasetExampleMixin(chainer.dataset.DatasetMixin): def __init__(self,X,y): self.X = X self.y = y def __len__(self): return len(X) def get_example(self,i): """ Fileを読み出す処理 """ return skimage.io.imread(X[i]), y[i]
最後に
Trainer凄く便利!これを使いこなしてTrainerを使えるChainer使いになりましょう。