アラカン"BOKU"のITな日常

文系システムエンジニアの”BOKU”が勉強したこと、経験したこと、日々思うことを書いてます。

学習済グラフを保存して、再利用する(python版):tensorflow入門の入門6/文系向け

今回は、学習済のグラフを保存して、再利用できるようにしてみます。 

ディープラーニングは、とにかく学習に時間がかかりますからね。 

裏でじっくり学習させて、例えば、アプリケーション等で使う時には、最適化された学習済のグラフを保存して再利用できないと、何の役にもたちません。 

保存対象にする学習済のグラフは、前回構築・学習したものを使います。

arakan-pgm-ai.hatenablog.com

 

学習結果の保存と再利用(読み込み)する機能

 

学習結果の保存と再利用(読み込み)する機能は、Tensorflowで、ちゃんと用意されてます。・・当たり前ですね。 

tensorflow.train.Saver() です。 

命令は、save() で保存、restore() で読み込みです。 

超簡単です。 

ただ、それを書く位置と順序に注意事項があるみたいなので、今回はその辺の確認をする感じなるのかな。

 

学習結果を保存する

 

保存機能を付け加えたコードを抜粋してみます。 

ベースは前回のコードで、付け加えた部分だけ、色を変えてます。

with tf.Session() as s:
     s.run(tf.global_variables_initializer())
     saver = tf.train.Saver()
     cwd = os.getcwd()
     for k in range(10):
          s.run(train_step,feed_dict={data:data_body,label:label_body})
     acc = s.run(accuracy, feed_dict={data:data_body_test,label:label_body_test})
     print("結果:{:.2f}%".format(acc * 100))
     saver.save(s,cwd + "\\model.ckpt")

 

1行ずつ確認していきます。 

まず、Saver()のオペレーションを作ります。

 saver = tf.train.Saver()

 

注意点としては以下の2点です。 

  • Sessionの中でなければならない。
  • tf.global_variables_initializer()の実行後でなければならない。

 

2つ目は、変数を使っている時だけです。 

もし、tf.global_variables_initializer()の実行前に書くと、実行時に変数がない(No variables to save)とエラーがでます。 

次の行は、Windows版のTensorflowの時だけは必ず必要です。

cwd = os.getcwd()

 

ネットのサンプルプログラムは、Linux版ベースで書かれているものが多いので、この処理は書いていません。 

でも、Windows版ではこうしないと、ディレクトリが見つからない(ValueError: Parent directory of model.ckpt doesn't exist, can't save.)と怒られます。

 

で、最後にsave()コマンドで保存します。

saver.save(s,cwd + "\\model.ckpt")

 

上記はWindows版のケースです。 

Linux版だと以下の書き方でもいけるはずですが、Windows版はこれだとエラーになります。

saver.save(s, "model.ckpt")

 

Windows版の場合は、絶対パスで指定してやらないと、ディレクトリを見つけられないみたいですね。 

細かいとこですけど、意外にはまりそうな箇所ではありますね。 

これで実行して成功すると、カレントディレクトリに以下のように最低4つのファイルができます。

f:id:arakan_no_boku:20170514234621j:plain

 

読み込んで再利用する

 

じゃあ、先程保存した学習済パラメータを使って、推論をやってみて前回と同じ結果がでるかどうか試してみます。 

学習済パラメータを読み込んだだけで、全く学習を行わないで、推論だけやるコードを抜粋します。 

ベースは前回のコードで、付け加えた部分だけ、色を変えてます。

with tf.Session() as s:
     s.run(tf.global_variables_initializer())
     saver = tf.train.Saver()
     cwd = os.getcwd()
     saver.restore(s,cwd + "\\model.ckpt")
     acc = s.run(accuracy, feed_dict={data:data_body_test,label:label_body_test})
    print("結果:{:.2f}%".format(acc * 100))

 

最初の2行は保存のときと同じです。 

読み込んで、パラメータをリストアしているのは以下の部分です。

saver.restore(s,cwd + "\\model.ckpt")

 

Windows版だけ絶対指定が必要なのも同じです。 

まあ、saveと全く同じ形で、restoreという名前なので、別に説明しなくても何をしているかはわかりますね。 

これで、学習済パラメータが再現されていれば、テストデータで分類した結果が前回と全く同じになるはずです。

 

さて、どうかな。

f:id:arakan_no_boku:20170513232641j:plain

 

おお!完璧じゃないですか。

 

2017/12/09追記

tensorflow v1.4で動作確認したところ、古いバージョンで保存したcheckpointファイルのリストアでエラーになるものがありました。 

こんなエラーメッセージです。

NotFoundError (see above for traceback): Key Variable_1 not found in checkpoint

 

とりあえず、v1.4でsave()しなおしてからresotore()と、問題なくなりますが、要注意ですね。

 

2018/02/12追記

tensorflow v1.5で動作確認時、v1.4で保存したものでリストアしたら、正解率が下がりました。(50.0%)

 

v1.5でsave()して、やり直したら問題なくなりました。 

どうも、バージョンが変わると、restore()は何かとあるみたいです。

 


Temsorflow入門の入門カテゴリの記事一覧はこちらですarakan-pgm-ai.hatenablog.com

 

f:id:arakan_no_boku:20170404211107j:plain