今回は、学習済のグラフを保存して、再利用できるようにしてみます。
ディープラーニングは、とにかく学習に時間がかかりますからね。
裏でじっくり学習させて、例えば、アプリケーション等で使う時には、最適化された学習済のグラフを保存して再利用できないと、何の役にもたちません。
保存対象にする学習済のグラフは、前回構築・学習したものを使います。
学習結果の保存と再利用(読み込み)する機能
学習結果の保存と再利用(読み込み)する機能は、Tensorflowで、ちゃんと用意されてます。・・当たり前ですね。
tensorflow.train.Saver() です。
命令は、save() で保存、restore() で読み込みです。
超簡単です。
ただ、それを書く位置と順序に注意事項があるみたいなので、今回はその辺の確認をする感じなるのかな。
学習結果を保存する
保存機能を付け加えたコードを抜粋してみます。
ベースは前回のコードで、付け加えた部分だけ、色を変えてます。
with tf.Session() as s:
s.run(tf.global_variables_initializer())
saver = tf.train.Saver()
cwd = os.getcwd()
for k in range(10):
s.run(train_step,feed_dict={data:data_body,label:label_body})
acc = s.run(accuracy, feed_dict={data:data_body_test,label:label_body_test})
print("結果:{:.2f}%".format(acc * 100))
saver.save(s,cwd + "\\model.ckpt")
1行ずつ確認していきます。
まず、Saver()のオペレーションを作ります。
saver = tf.train.Saver()
注意点としては以下の2点です。
- Sessionの中でなければならない。
- tf.global_variables_initializer()の実行後でなければならない。
2つ目は、変数を使っている時だけです。
もし、tf.global_variables_initializer()の実行前に書くと、実行時に変数がない(No variables to save)とエラーがでます。
次の行は、Windows版のTensorflowの時だけは必ず必要です。
cwd = os.getcwd()
ネットのサンプルプログラムは、Linux版ベースで書かれているものが多いので、この処理は書いていません。
でも、Windows版ではこうしないと、ディレクトリが見つからない(ValueError: Parent directory of model.ckpt doesn't exist, can't save.)と怒られます。
で、最後にsave()コマンドで保存します。
saver.save(s,cwd + "\\model.ckpt")
上記はWindows版のケースです。
Linux版だと以下の書き方でもいけるはずですが、Windows版はこれだとエラーになります。
saver.save(s, "model.ckpt")
Windows版の場合は、絶対パスで指定してやらないと、ディレクトリを見つけられないみたいですね。
細かいとこですけど、意外にはまりそうな箇所ではありますね。
これで実行して成功すると、カレントディレクトリに以下のように最低4つのファイルができます。
読み込んで再利用する
じゃあ、先程保存した学習済パラメータを使って、推論をやってみて前回と同じ結果がでるかどうか試してみます。
学習済パラメータを読み込んだだけで、全く学習を行わないで、推論だけやるコードを抜粋します。
ベースは前回のコードで、付け加えた部分だけ、色を変えてます。
with tf.Session() as s:
s.run(tf.global_variables_initializer())
saver = tf.train.Saver()
cwd = os.getcwd()
saver.restore(s,cwd + "\\model.ckpt")
acc = s.run(accuracy, feed_dict={data:data_body_test,label:label_body_test})
print("結果:{:.2f}%".format(acc * 100))
最初の2行は保存のときと同じです。
読み込んで、パラメータをリストアしているのは以下の部分です。
saver.restore(s,cwd + "\\model.ckpt")
Windows版だけ絶対指定が必要なのも同じです。
まあ、saveと全く同じ形で、restoreという名前なので、別に説明しなくても何をしているかはわかりますね。
これで、学習済パラメータが再現されていれば、テストデータで分類した結果が前回と全く同じになるはずです。
さて、どうかな。
おお!完璧じゃないですか。
2017/12/09追記
tensorflow v1.4で動作確認したところ、古いバージョンで保存したcheckpointファイルのリストアでエラーになるものがありました。
こんなエラーメッセージです。
NotFoundError (see above for traceback): Key Variable_1 not found in checkpoint
とりあえず、v1.4でsave()しなおしてからresotore()と、問題なくなりますが、要注意ですね。
2018/02/12追記
tensorflow v1.5で動作確認時、v1.4で保存したものでリストアしたら、正解率が下がりました。(50.0%)
v1.5でsave()して、やり直したら問題なくなりました。
どうも、バージョンが変わると、restore()は何かとあるみたいです。
Temsorflow入門の入門カテゴリの記事一覧はこちらですarakan-pgm-ai.hatenablog.com