ナード戦隊データマン

データサイエンスを用いて悪と戦うぞ

機械学習で作成したモデルをREST APIとしてdeployする[python]

モデルを運用する方法の一つとして、REST APIがあります。これは、予測したいデータをWebベースのAPIに送信することで、予測値のレスポンスを取得する方法です。ここでは、pythonでflaskを用いて試します。

joblib.dumpでモデルを保存する

作成したモデルをjoblib.dumpを用いてファイルとして保存することができます。

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import pandas as pd

iris = datasets.load_iris()
X = pd.DataFrame(iris.data, columns=["sepal_length", "sepal_width", "petal_length", "petal_width"])
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42)

clf = LogisticRegression()
clf.fit(X_train, y_train)
#print(clf.score(X_test, y_test))

from sklearn.externals import joblib
joblib.dump(clf, 'iris_logreg.pkl')
joblib.dump(["sepal_length", "sepal_width", "petal_length", "petal_width"], 'iris_logreg_cols.pkl')

ここでは、特徴量名の配列もdumpしています。これは、APIを呼び出される側で、特徴量名の順番を保持する必要があるためです。Rとは違い、特徴量の名前ではなく、順番で認識しています。

FlaskでAPIを作成する

ここで作成するAPIの大雑把な仕様は以下です。

  1. モデルファイル名をURLに渡す。
  2. 指定したモデルファイルに対し、予測したいデータを渡すと予測値の配列がjson形式で返る。
from sklearn.externals import joblib
from flask import Flask, jsonify, request
import pandas as pd
from sklearn import datasets

app = Flask(__name__)

@app.route('/predict/<string:clf_file>', methods=['POST'])
def predict(clf_file):
    clf = joblib.load("{}.pkl".format(clf_file))
    data = request.json
    query = pd.DataFrame(data)
    cols = joblib.load("{}_cols.pkl".format(clf_file))
    query = query[cols]
    prediction = clf.predict(query)
    return jsonify({'prediction':prediction.tolist()})


if __name__ == '__main__':
    app.run(port=8080)

これを、mlapi.pyと名付け、以下のコマンドで実行します。

$ python mlapi.py &

requestを送って試してみる

import requests
import pandas as pd
import numpy as np

send_data = []
for index, X_row in X_test.iterrows():
    row = {"sepal_length": X_row['sepal_length'], "sepal_width":X_row['sepal_width'], "petal_length": X_row['petal_length'], "petal_width": X_row['petal_width']}
    send_data.append(row)
    
r = requests.post("http://localhost:8080/predict/iris_logreg", json=send_data)
print(r.text)
print(y_test)
{
  "prediction": [
    1,0,2,1,1,0,1,2,1,1,2,0,0,0,0,1,2,1,1,2,0,2,0,2,2,2,2,2,0,0
  ]
}

[1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0]

複数のモデルを同時に運用するには

このAPIは、モデルファイル名を渡すことによって機能しています。モデルファイル名は、特徴量名の順番を保持した配列ファイルにも利用されるため、ファイル名のルールさえ守れば、複数のモデルを運用することが可能です。

URLの形式は、以下のようになっています。

/predict/<string:clf_file>

<string:clf_file>の部分に、モデルをダンプしたファイル名を渡すことができます。

参考

medium.com