今回から、「python版Tensorflowで学習させた学習済グラフを、JAVAで再利用する。」をやってみます。
簡単な利用イメージはこんな感じです。
SpringBoot+Javaで動いているWEBアプリケーションに、ディープラーニングを使って推論する機能を組み込むのが目的です。
ただ、手順は若干おおいです。
一気にそれを書くと、めちゃくちゃ記事が長くなるので、3回にわけて書きます。
3-1:pythonで学習済モデルをJAVAで再利用できる形式で保存する。(この記事)
3-2:保存した学習済モデルをJAVAで利用できるようにマージする arakan-pgm-ai.hatenablog.com
3-3: マージした学習済モデルをJAVA側で読み込んで推論する。
今回は、1回目です。
pythonで学習済モデルをJAVAで再利用できる形式で保存する
Tensorflowでエクスポート(保存)できるモデル(core artifacts)には2種類あったんですね。
この2種類です。
で、前回の記事で紹介した、tensorflow.trani.Saver()で保存できるののはcheckpointsのみです。
GraphDefは、tf.train.write_graph()を使って、testmodel.pbなどの名前をつけた、.pb(プロトコルバッファ)ファイルにします。
Graphと変数の状態ですから、JAVA等で再利用しようと思ったら、この2つがセットで存在しないといけないわけです。
ここが仲々わからなくて、checkpointsだけでやろうとして、変数が未定義だと怒られ、pb(プロトコルバッファ)ファイルのみでやろうとして、変数が初期化してないと怒られ・・を繰り返して、1週間くらいはまりました。
ということで、前回から、いくつか保存方法を変更します。
Tensorflowで学習モデルを保存する方法の変更点
変更点を先にまとめときます。
- 全体を名前をつけたGraphで囲む。(色)
- すべてのTensorに参照用の名前をつける。(色)
- Graphを保存する処理を追加する。(色)
です。
先にソースコード全体をのせます。
上記の変更点にあたる部分には色をつけてます。
import tensorflow as tf
import csv
import math
import os#csvデータを読み込む部分は割愛してます。前回以前の記事参照。
gr = tf.Graph()
with gr.as_default():
bias = tf.Variable(tf.zeros([2],dtype=tf.float32),name="bias")
stddev_1 = 2.0 / math.sqrt(5 * 2)
weight = tf.get_variable("weight",[5,2],initializer=tf.random_normal_initializer(stddev=stddev_1,dtype=tf.float32))
bias_h = tf.Variable(tf.zeros([5],dtype=tf.float32),name="bias_h")
stddev_2 = 2.0 / math.sqrt(5 * 5)
weight_h1 = tf.get_variable("weight_h1",[5,5],initializer=tf.random_normal_initializer(stddev=stddev_2,dtype=tf.float32))
weight_h2 = tf.get_variable("weight_h2",[5,5],initializer=tf.random_normal_initializer(stddev=stddev_2,dtype=tf.float32))
weight_h3 = tf.get_variable("weight_h3",[5,5],initializer=tf.random_normal_initializer(stddev=stddev_2,dtype=tf.float32))
data = tf.placeholder(dtype=tf.float32,shape=[None,5],name="data")
label = tf.placeholder(dtype=tf.float32,shape=[None,2],name="label")
hidden_1 = tf.nn.relu(tf.matmul(data,weight_h1) + bias_h,name="hidden_1")
hidden_2 = tf.nn.relu(tf.matmul(hidden_1,weight_h2) + bias_h,name="hidden_2")
hidden_3 = tf.nn.relu(tf.matmul(hidden_2,weight_h3) + bias_h,name="hidden_3")
y = tf.nn.softmax(tf.matmul(hidden_3,weight) + bias,name="y")
cross_entropy = tf.reduce_mean(-tf.reduce_sum(label * tf.log(y), reduction_indices=[1]),name="cross_entropy")
train_step = tf.train.GradientDescentOptimizer(0.1,name="train_step").minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(label,1),name="correct_prediction")
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"),name="accuracy")
cwd = os.getcwd()
with tf.Session(graph=gr) as s:
s.run(tf.global_variables_initializer())
saver = tf.train.Saver()
for k in range(10):
s.run(train_step,feed_dict={data:data_body,label:label_body})
acc = s.run(accuracy, feed_dict={data:data_body_test,label:label_body_test})
print("結果:{:.2f}%".format(acc * 100))
saver.save(s,cwd + "\\model.ckpt")
gr_def = gr.as_graph_def()
tf.train.write_graph(gr_def,cwd,"testmodel.pb",as_text=False)
上記は、Windows版のtensorflowでやってます。
なので、「 cwd = os.getcwd()」でカレントフォルダのパスを取得して、いちいち絶対パスにしてからファイルを読み書きしています。
そうしないと、Windows版ではうまくいかないので。
Linux版だと、そういうのは必要なく、ファイル名のみまたは 「./model.ckpt」みたいな書き方でいけるので、そこは必要に応じて書き換えてください。
ポイントを確認していきます。
全体を名前をつけたGraphで囲む。
あとでGraphを保存しなければいけないので、扱いやすいように名前をつけてるだけなんですけどね。
gr = tf.Graph()
with gr.as_default():
こんな感じでグラフオブジェクトを作って、そのWith句でデフォルトのグラフを使うことを宣言してます。
CSVデータから読み込む処理とか、再利用するときにも外部で行う必要のある部分は、Graphで囲む前に書いてはずしておくほうがいいみたいですね。
すべてのTensorに参照用の名前をつける。
以下例みたいに、すべてのTensor(テンソル)/オペレーションに、name="xxx"で名前を追加してます。
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"),name="accuracy")
pythonでもJAVAでも、学習済モデルを再利用する時には、こちらの名前でアクセスする必要があるからです。
この名前をつけておかないと、実行したときに、「Noneにアクセスしようとしてるぞ!」と怒られます。
これがまた、原因がわかりづらいエラーになります。
ですから、とりあえずエラーが回避する方を優先して、すべてに名前をつけておくようにしています。
Graphを保存する処理を追加する。
前回で、checkpointsを保存する処理は書いてあるので、今回はGraphを保存する処理だけを追加で書きます。
gr_def = gr.as_graph_def()
tf.train.write_graph(gr_def,cwd,"testmodel.pb",as_text=False)
現在のグラフから、GraphDef・・ようするにGraphをシリアライズしたもの・・を生成して、それを、tf.train.write_graph()を使って、testmodel.pbという名前のファイルに書き出しているだけです。
エクスポートされたファイルの確認
うまく動いていれば、以下のようなファイルができているはずです。
- checkpoint
- model.ckpt.data-00000-of-00001
- model.ckpt.index
- model.ckpt.meta
- testmodel.pb
上の4つがチェックポイントのエクスポートデータ、最後の一つがGraphのエクスポートデータです。
あとは、これをJAVA側で読み込んでやればいいな。
と、そう思ったら、いやいやとんでもない。
実は、チェックポイントのエクスポートデータをJAVA側でどうやっても読み込めなかったんです。
どうも、現時点では
このファイルを全部マージして、変数の状態を反映させたpb(プロトコルバッファ)ファイルを作ってやる。
または、全く別の形式でエクスポートして、SavedModelBundleを利用して、JAVA側でロードする方法をとる。
このどちらかを行う必要があるみたいなんです。
ただ、SavedModelBundleの関連機能は、正直まだ tf.contrib.XXXの機能と同様に、依然開発中っぽいという情報があります。
それに、あとあとの事を考えると、とりあえず、古い(ベーシックな)やり方も知っておかないと、応用がきかなったら嫌なので、今回は前者の方法をとることにしました。
2017/12/09追記
>いちおう、tensorflow v1.4で動作確認しました。
2018/02/12追記
>tensorflow v1.5で動作確認しました。
このマージする方法については、次回の記事(3-2)で書きます。
Tensorflow入門の入門カテゴリの記事一覧はこちらです。