アラカン"BOKU"のITな日常

文系システムエンジニアの”BOKU”が勉強したこと、経験したこと、日々思うことを書いてます。

マージした学習済モデルをJAVAで読み込んで推論する:Tensorflow入門の入門9/文系用

今回はpythonで学習させた「学習済グラフ」を保存したファイルを使って、JAVAでグラフを再構築して、評価などの機能に利用する方法をやります。 

3回にわけて書いてきた記事の3回目です。 

初めて見る方は、学習済モデルの保存(3-1)と、JAVAで利用できるようにマージする(3-2)から続けてみてもらった方が、わかりよいと思います。

 

3-1:pythonで学習済モデルをJAVAで再利用できる形式で保存する。

 

3-2:保存した学習済モデルをJAVAで利用できるようにマージする

 

3-3:マージした学習済モデルをJAVA側で読み込んで推論する。(この記事。)

 

さて、まずはJAVA側で利用する準備です。

 

マージした学習済モデルをJAVA側で読み込んで推論する

 

自分の開発環境はSTSで、フレームワークとしてSpringBootを使ってます。 

そこでTensorflow連携テスト用のプロジェクトを、Mavenベースで作ってます。 

その前提下での情報であることはご容赦ください。 

もし、Mavenベースでない環境をお使いの場合はお手数ですが、こちらの記事を見てもらって、環境構築をお願いします。


Mavenベースの場合、Tensorflowを使えるようにするのは、pom.xmlに以下を追加するだけです。ただし、自分はwindows版のTensorflowを使っていることに注意してください。

<dependency>
     <groupId>org.tensorflow</groupId>
     <artifactId>tensorflow</artifactId>
     <version>1.1.0-rc0-windows-fix</version>
</dependency>

 

Windows版以外なら、Googleで「Maven tensorflow」とでも検索すれば設定情報はすぐ見つかりますので、それを参照してくださいね。 

 

2017/12/09追記

>現在はtensorflowのバージョンはv1.4です。

>そのため、v1.4で保存したモデルは、java版もv1.4でないとリストアできません。

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
     <version>1.4.0</version>
</dependency>

>以前は、「1.1.0-rc0-windows-fix」でしたが 、今は1.4.0だけでよいみたいです。

 

2018/02/12追記

>v1.5にバージョンアップしたので、 <version>1.5.0</version> に変更します。

 

JAVAでTensorflowの学習済モデルを利用する

 

さて、やっと本題です。 

学習済モデルは、前回マージした「frozen_model.pb」を使います。 

やることは以下の通りです。

  1. frozen_model.pbを読み、Graphオブジェクトを構築する。
  2. 構築したGraphでSessionオブジェクトを構築する。
  3. テスト用データのTensorオブジェクトを構築する。
  4. 推論を実行する。
  5. オブジェクトをクローズしてメモリを解放する。

 

最初にソースを全文載せて、ポイントにわけて確認する方式でやっていきます。

import java.io.File;
import java.nio.file.Files;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class ExampleTf {
    public static void main(String args) throws Exception {
        final File modelFile = new File("C:\\XXX\\frozen_model.pb");
        byte
graphDef = Files.readAllBytes(modelFile.toPath());
       Graph graph = new Graph();
       graph.importGraphDef(graphDef);
       Session session = new Session(graph);
       float d = new float{new float{20.0f ,20.0f ,20.0f ,0.0f ,0.0f}};
       float
l= new float{new float{1.0f ,0.0f}};
      Tensor<Float> data = Tensor.create(d,Float.class);
      Tensor<Float> label = Tensor.create(l,Float.class);
       Tensor acc = session.runner().feed("data", data).feed("label",label).fetch("accuracy").run().get(0);
       System.out.println("結果:" + acc.floatValue());
       data.close();
       label.close();
       acc.close();
       session.close();
       graph.close();
    }
}

 

2017/12/09追記

 >Tensorの初期化の部分で警告がでていたので、ちょっと訂正しました。

Tensor<Float> data = Tensor.create(d,Float.class); の部分です。

>ただ、「Tensor acc = session.runner().feed・・・」の部分はそのままです。

>だから、Tensorはraw型です・・という警告が、1個だけ残ります。

>ここは、google本家のサンプルでもこうなってます。

https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/Session

>なので、しょうがないのかな・・と。

 

2018/02/12追記

>v1.5になってもそこはそのままです。 

 

 frozen_model.pbを読み、Graphオブジェクトを構築する

 

この部分のコードです。 

サンプルとしてのわかりやすさを優先してフルパスなどを直接書いてたりしてますが、そのへんの手抜きはご容赦ください。 

なお、このパスの書き方はWindows専用(Windows版Tensorflow for java専用)です。

final File modelFile = new File("C:\\XXX\\frozen_model.pb");
byte graphDef = Files.readAllBytes(modelFile.toPath());
Graph graph = new Graph();
graph.importGraphDef(graphDef);

 

Fileオブジェクトを構築し、内容をbyte配列で読み込んでます。 

ファイルから読み込んだByte配列は、シリアライズされた状態のGraphですから、空のGraphオブジェクトを生成して、インポートしてやる必要があります。  

まんまですね。

 

構築したGraphでSessionオブジェクトを構築する

 

この部分のコードです。

Session session = new Session(graph);

 

Sessionオブジェクトを構築して、runner()で実行しなければならないのは、python版と同じです。

 

テスト用データのTensorオブジェクトを構築する

 

とりあえず、わかりやすさ優先でテストデータは1件だけ手作りにしてみます。 

今回の場合は、python側のデータ・タイプが「tf.float32」なので、Java側では floatになります。 

ただ、Tensorflow for java で引数などに使えるのは、やっぱり基本的にはTensorだけですので、データを作成する手順としては以下になります。

  1. 対象のTensorのRankに合わせてfloatのデータを作る。
  2. 上記データで、Tensorオブジェクトを構築する。

 

対象のTensorのRankに合わせてfloatのデータを作るコードです。

上段はデータ、下段がラベルです。

float d = new float{new float{20.0f ,20.0f ,20.0f ,0.0f ,0.0f}};
float l= new float{new float[]{1.0f ,0.0f}};

 

上記データで、Tensorオブジェクトを構築する部分のコードです。

同様に上段はデータ、下段がラベルです。

Tensor data = Tensor.create(d);
Tensor label = Tensor.create(l);

 

推論を実行する

 

とりあえず、1件だけですがデータを与えて推論を実行します。 

その部分と、結果を出力しているコードです。

Tensor acc = session.runner().feed("data", data).feed("label",label).fetch("accuracy").run().get(0);
System.out.println("結果:" + acc.floatValue());

 

Sessionオブジェクトのrunner()を使います。 

feed("data",data)の第一引数は、python側でGraphを保存する際につけた名前(この例だと、"data")を指定します。 

feed()は、pythonのコードでのfeed_dictに対応します。 

fetch("accuracy")でも、同様に保存時にオペレーションにつけた名前を指定します。 

結果はTensorオブジェクトで返されますので、そこからfloatValue()で値をとりだして、コンソールに出力しています。

 

オブジェクトをクローズしてメモリを解放する

 

ここが一番重要です。 

Graph、Session,Tensorなどのオブジェクトを構築したら、必ず明示的にClose()しなければなりません。 

自動的にメモリが開放されることはありません。 

メモリリークとかの原因にもなりかねないので、必ず、Close()する癖はつけといたほうが良さそうですね。

 

試してみます。

 

実行してみました。

結果:1.0

 

いけてそうですね。 

なお、実行すると以下のようなワーニングメッセージが4行くらい必ず出力されます。

The TensorFlow library wasn't compiled to use SSE instructions, but these are available on your machine and could speed up CPU computations.

エラーではありませんから、気にしなくても良いです。

 

まとめ

 

Tensorflow for java と、python版の両方で作業をした感覚としては、圧倒的にpython版の方が機能も豊富ですし、使いやすいです。 

コードもシンプルに書けますしね。 

だから、ニューラルネットワークを構築したらい、学習するような複雑な処理は、python版でやった方が絶対良いです。 

でも、Webシステムとかで何かを入力させて、それを学習済モデルで推論して結果を返すような仕組みを作るなら、フロント部分はJAVAで開発する方が、圧倒的に生産性が高いです。 

だから、python版でバックグラウンドで学習させて、学習済モデルを連携して、Javaで開発したWebシステム側で使うという役割分担ができればいいなと思ってました。 

今回確認できたことで、それがやれることがわかったのは、うれしいですねえ。

 

2017/12/09追記

いちおう、tensorflow for java v1.4で動作確認しました。

 

2018/02/12追記

tensorflow v1.5で動作確認しました。

 


Tensorflow入門の入門カテゴリの記事一覧はこちらです。

arakan-pgm-ai.hatenablog.com

f:id:arakan_no_boku:20170404211107j:plain