Python
JavaScript
TensorFlow
TensorFlow.js

TensorFlow.jsでPython学習済モデルを読み込み、ブラウザで予測結果を出力する

はじめに

TensorFlow Developer Summit 2018 にて Webブラウザ上で機械学習のモデルの構築、学習、学習済みモデルの実行などが可能になるJavaScriptライブラリ「TensorFlow.js」がGoogleによって公開されました。
今はやりっぽいのでJavascriptの勉強もかねてちょっと動かしてみました。

目的

Python モデルをTensorFlow.js で読み込む方法メモ

作ったもの

Pythonのモデルを用いた推測結果をTensorFlow.jsのコンソール上で表示
作成物.png

コード

https://github.com/hiyashichuka/tfjs-iris
※以下を参考にしました
https://github.com/tensorflow/tfjs-examples/tree/master/iris

環境

  • OS

    • Windows 10
    • Bash on Ubuntu on Windows 14.04.5 LTS, Trusty Tahr
  • node

    • node v8.11.3
    • npm 5.6.0
  • Python

    • Python 3.5.5 |Anaconda custom (64-bit)|
  • pip list

    • tensorboard 1.8.0
    • tensorflow 1.8.0
    • tensorflow-gpu 1.8.0
    • tensorflow-hub 0.1.0
    • tensorflow-tensorboard 1.5.1
    • tensorflowjs 0.5.2

主な流れ

  1. Python学習済みモデルを作成 ⇒ Pythonで実行
  2. モデルをTensorFlow.jsで読み込める形に変換 ⇒ Pythonで実行
  3. 2の学習済みモデルを用いて、任意の花の測定値からアイリスのどの品種に属するかを予測 ⇒Javascriptで実行
  4. 予測結果をTensorFlow.jsで表示 ⇒Javascriptで実行

3,4を具体的に
任意の4データを学習済みモデルに入れて(Sepal Length(がく片の長さ), Sepal Width(がく片の幅), Petal Length(花びらの長さ), Petal Width(花びらの幅))'Setosa', 'Versicolor', 'Virginica'のどれに分類されるかをコンソール上で表示させる

1. Python学習済みモデルを作成

Irisクラス分類学習済みモデルを生成します。
今回モデルを作ることは目的ではないので、以下のGoogleのサンプルコードを使います

https://github.com/tensorflow/tfjs-examples/tree/master/iris/python

2. モデルをTensorFlow.jsで読み込める形に変換

1の iris_data.pyを実行します
実行環境がないならGoogle Colaboratoryを使うといいかもです

python iris_data.py

以下のファイルが\tmp\iris.kerasに生成されます
model.jsonを読み込みに使います     

group1-shard1of1  
group2-shard1of1 
model.json

   

3 学習済みモデルを用いて、任意の花の測定値からアイリスのどの品種に属するかを予測

データ読み込み

tf.loadModelからmodel.jsonを読み込みます。

  const MODEL_JSON_URL = /* model.jsonのPath */
  const model = await tf.loadModel(MODEL_JSON_URL);

テストデータを用意

今回はPetal length, Petal width, Sepal length, Sepal widthの四つのデータが必要なので、inputData で渡します

  // Input four date
  // Petal length, Petal width, Sepal length, Sepal width
  const inputData = [5.1, 3.5, 1.4, 0.2];
  const input = tf.tensor2d([inputData], [1, 4]);

予測

model.predictにinputデータを入れることで予測が可能です

  const predictOut = model.predict(input);

4. 予測結果をTensorFlow.jsで表示

結果出力

  const logits = Array.from(predictOut.dataSync());
  console.log("Setosa Probabilities :" + logits[0]);
  console.log("Versicolor Probabilities :" + logits[1]);
  console.log("Virginica Probabilities :" + logits[2]);

  const winner = IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
  console.log("Predict IRIS Class :" + winner);

上記結果

Setosa Probabilities :0.956368625164032
Versicolor Probabilities :0.04204682260751724
Virginica Probabilities :0.0015846537426114082
Predict IRIS Class :Setosa

Setosa > Versicolor > Virginica なのでSetosaクラスに分類されることがわかりました

メモ

  • tfjs-converterを使わない理由

GoogleがSavedModelを使うのを推奨しているため

(Note: TensorFlow has deprecated session bundle format, please switch to SavedModel.)
https://github.com/tensorflow/tfjs-converter

参考

tfjs-examples
https://github.com/tensorflow/tfjs-examples

TensorFlow.jsでMNIST学習済モデルを読み込みブラウザで手書き文字認識をする
https://qiita.com/kaneU/items/ca84c4bfcb47ac53af99

TensorFlow 1.7 新機能 サンプルコードまとめ
https://qiita.com/akimach/items/d150fce405aff37dd463#get-started.md