アラカン"BOKU"のITな日常

文系システムエンジニアの”BOKU”が勉強したこと、経験したこと、日々思うことを書いてます。

保存した学習済モデルをJAVAで利用できるようにマージする:Tensorflow入門の入門8/文系向け

今回は、pythonで学習済のグラフをJAVAで再利用して推論等に使う方法の2回めです。 

今回やることは、エクスポートした学習済モデルと変数状態のデータをマージして、Tensorflow for JAVAで読み込んで再利用できるようにすることです。 

pythonで学習したモデルをJAVA側で再利用する手順は、少々長いので、記事を3回にわけて書いてます。

 

3-1:pythonで学習済モデルをJAVAで再利用できる形式で保存する。


3-2:保存した学習済モデルをJAVAで利用できるようにマージする(この記事)


3-3:マージした学習済モデルをJAVA側で読み込んで推論する。

arakan-pgm-ai.hatenablog.com

 

使うデータは前回、保存したものです。  

今回は、その2回目になります。

 

保存した学習済モデルをJAVAで利用できるようにマージする

 

さて、ここでマージするツールが必要になります。 

実は、エクスポートしたchekpointsデータと、GrapfDefデータ(pbファイル)をマージするツールは標準で用意されています。 

インストールした ”tensorflow\tensorflow\python\tools” フォルダの下にある、freeze_graph.py というツールがそれです。 

ただ、今回は結局使いませんでした。 

理由は単純です。 

面倒くさかったからです。 

自分はpythonの統合環境「IDLE」を使ってやってます。 

実行する時はソースを表示して、F5キーを押すだけという操作が身に染み付いてます。 

ところが、このfreeze_graph.pyをコマンドラインから実行しようと思うと、IDLEを一回終了させないといけなかったんですね。 

たいした手間ではないですけど、それが「面倒くさかった」んです。 

だから、ソースに、freeze_graphをインポートして動かすプログラムを作って、それを開いてF5キーで実行・・とやろうとしたんですけど、非常に汎用的に作られているので、引数が非常に多いんですね。 

なので、ソースを読みながら、適切な引数の与え方を調べてたんですけど、それより、自前で必要な部分だけ抜粋して、同等の機能を実装した方がはやいな・・と思ったわけです。 

その方が、勉強にもなるし。 

ということで、以下、自前で作った簡易版「freeze_graph」を作っていきます。 

 

自前版freeze_graphを作る

 

まず、やるべきことをまとめて見ます。 

前回、エクスポートでできたファイルは以下でした。

  • checkpoint
  • model.ckpt.data-00000-of-00001
  • model.ckpt.index
  • model.ckpt.meta
  • testmodel.pb

 

上の4つが、checkpoint(変数の状態をシリアライズしたもの)をエクスポートしたもの、最後のひとつがGraphDef(Graphをシリアライズしたもの)をエクスポートしたものですね。 

マージでやるべきことは

  1. testmodel.pbをインポートして、Graphを再構築する。
  2. checkpointをインポートして、変数の状態を再構築する。
  3. 変数を一旦、constant(定数)に置き換える
  4. その状態のGraphをエクスポートする

です。 

つまり、変数のままだと、Graphをエクスポートした時に「変数が初期化されていない状態」になってしまうので、学習済の値を保持したまま、変数を全部定数に置き換えてしまってから、もう一度、Graphをエクスポートしよう・・ということなんですね。 

文字とおり、freeze・・なわけです。 

先にソースを全文書いておきます。 

とりあえず版なので、ファイル名とか直接書いてて、関数化も中途半端なんですけど、そこはご容赦を。

import os
import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph():
    cwd = os.getcwd()
    output_graph = cwd + "\\frozen_model.pb"
    output_node_names = "accuracy"
    clear_devices = True
    saver = tf.train.import_meta_graph(cwd + '\\model.ckpt.meta', clear_devices=clear_devices)
    with tf.gfile.FastGFile(cwd + "\\testmodel.pb","rb") as f:
        input_graph_def = tf.GraphDef()
        input_graph_def.ParseFromString(f.read())
    with tf.Session() as sess:
        saver.restore(sess, cwd + "\\model.ckpt")
        output_graph_def = graph_util.convert_variables_to_constants(
sess, input_graph_def, output_node_names.split("/") )
    with tf.gfile.GFile(output_graph, "wb") as f:
         f.write(output_graph_def.SerializeToString())

freeze_graph()

 

さて、ポイントを確認していきます。

 

testmodel.pbをインポートして、Graphを再構築する。

 

これをやっているのがこの部分です。

with tf.gfile.FastGFile(cwd + "\\testmodel.pb","rb") as f:
     input_graph_def = tf.GraphDef()
     input_graph_def.ParseFromString(f.read())

 

まあ、ファイルを開いて、読み込んで、シリアライズされているのをパースしてGraphに戻しているわけなので、まんまです。 

cwdにはカレントフォルダのパスがはいっています。 

Windows版のみこのような書き方が必要で、Linux版の場合はファイル名だけ、または ./ testmodel.pb でいいはずです。

 

checkpointをインポートして、変数の状態を再構築する。 

 

これをやっている部分です。

clear_devices = True
saver = tf.train.import_meta_graph(cwd + '\\model.ckpt.meta', clear_devices=clear_devices)

saver.restore(sess, cwd + "\\model.ckpt") 

 

ポイントは、saverオブジェクトを新規に生成するのではなくて、.metaファイルからインポートしているところです。 

そして、clear_devicesをTrueにすることで、GPU使用とかの別の環境で再利用するにはいらない情報をクリアしておきます。 

それでsaverオブジェクトを作ってしまったら、普通にrestoreするだけですね。

 

変数を一旦、constant(定数)に置き換える

 

これをやっている部分です。

output_node_names = "accuracy"

output_graph_def = graph_util.convert_variables_to_constants(
sess, input_graph_def, output_node_names.split("/") )

 

tensorflowのツールを使ってます。 

引数のsessはセッション、input_graph_defはさっきファイルから戻したGraphなので迷うところはありません。 

output_node_names = "accuracy"の部分は、保存したGraphで最終アウトプットにしているオペレーションの名前にあわせてます。 

output_node_names.split("/")は、今回の場合あまり意味がないのですけど、とりあえず、これで動いたので、こうしてます。 

ただ、保存する時にname="xxxx"ですべての変数に名前をつけてないと、Noneへのアクセスでエラーがでて怒られるので注意が必要です。

 

その状態のGraphをエクスポートする

 

この部分のコードです。

output_graph = cwd + "\\frozen_model.pb"
with tf.gfile.GFile(output_graph, "wb") as f:
     f.write(output_graph_def.SerializeToString())

 

さっき、変数を定数に置き換えたGraphをシリアライズして、ファイルに書き込んでいるだけです。

 

実行して結果を確認してみる

 

実行すると、frozen_model.pb というファイルが、カレントフォルダにできました。 

いけてるっぽいのですが、念のため、pythonで読み込んで実行してみようかと思います。 

結果確認用のコードです。

import tensorflow as tf
import csv
import math
import os

cwd = os.getcwd()
data_body_test,label_body_test = csv_loader("sample_test.csv")
with tf.gfile.FastGFile(cwd + "\\frozen_model.pb","rb") as f:
     graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def,name='')
s = tf.Session()
acc = s.run("accuracy:0", feed_dict={"data:0":data_body_test,"label:0":label_body_test})
print("結果:{:.2f}%".format(acc * 100))

 

超シンプルになりました。 

ほんとにGraphを読み込んで実行結果を得るだけです。 

注意すべきポイントは、Graphに含まれるオペレーションの指定方法です。

acc = s.run("accuracy:0", feed_dict={"data:0":data_body_test,"label:0":label_body_test})

 

name="data"で名前をつけたオペレーションに値をセットするには、"data:0"みたいに、":0"をつけて参照しないといけないのですね。 

これを付け忘れるとエラーになります。 

これも最初のうち、エラーメッセージからなかなか読み取れなくて苦労しました。(^_^;) 

さて実行してみましょう。 

不正解データもまぜてあるので、92~93%なら、きちんと学習モデルが受け渡されていると判断できます。

f:id:arakan_no_boku:20170513232641j:plain

 

バッチリです。

 

2017/12/09追記

いちおう、tensorflow v1.4で動作確認しました。

 

2018/02/12追記

tensorflow v1.5で動作確認しました。

 

次回はJAVAで学習済モデルを読み込むのに挑戦してみます。

arakan-pgm-ai.hatenablog.com

 


Tensorflow入門の入門の記事の一覧はこちらです。

arakan-pgm-ai.hatenablog.com

 

f:id:arakan_no_boku:20170404211107j:plain