6

投稿日

更新日

機械学習向けWeb UIライブラリ Gradio

はじめに

Gradioは機械学習モデルを操作するためのWeb UIを簡単に作成できるPythonライブラリです。

例えば、ファイルから画像を読み込んで分類するUIを作るとします。そこには画像をアップロードするためのフォームや、分類結果を表示するためのテキストフィールドが必要です。
Gradioを使えば、これらの要素を含むUIを簡単に作成することができます。

以下はHuggingFaceのWebサイトに公開されている画像分類モデルを使って、画像の分類結果を表示するUIを作成した例です。examples引数で与えているのはUIに渡せるサンプル画像です。

import gradio as gr

gr.Interface.load(
    "huggingface/google/vit-base-patch16-224",
    examples=["./images/cat.jpg", "./images/man.jpg"]).launch()

このコードを実行するとhttp://localhost:7860でUIが動きます。この画面のキャプチャを以下に示します。

sample_1.png

基本的なUIの作り方

HuggingFaceで公開済みのモデルの場合は、先のようにモデル読み込み機能を使うことでUIを構築できます。

では自分がローカルで作成したモデルのUIを作るにはどうすればいいでしょう。そのためにまずは、UI構築用クラスのgradio.Interfaceのインスタンスの基本的な作り方を見ていきましょう。

以下は、https://gradio.app/quickstart/で公開されているうちで最も単純なgradio.Interfaceの作成方法です。

import gradio as gr

def greet(name):
    return "Hello " + name + "!"

demo = gr.Interface(fn=greet, inputs="text", outputs="text")

demo.launch()   

Interfaceのコンストラクタ引数の意味は下記です。

  • fn: 入力を処理して出力を生成する関数。
  • inputs: 入力データの型(ex. number)やその入力形態(ex. スライダーバー)を指定する引数。上記ではtext型を指定。リストを使えば複数のデータの型や入力形態を指定可。
  • outputs: 出力データの型や出力形態を指定する引数。inputs同様に、リストを使えば複数指定可。

したがって上記のコードによって、text型の入力を受け取ってgreet関数に渡し、その結果をtextとして出力するUIを作ることができます。実際には下記のUIが生成されます。

sample_2.png

nameの欄に文字列を入れて送信ボタンを押すと、outputgreet関数の戻り値が表示されます。
(よく見ると気の利くことに、greet関数の引数名のnameが入力欄に反映されています)

sashie.png

自作モデルのUIの作り方

前章では、①入力を処理する関数、②入力のデータ型や入力形態、③出力のデータ型や出力形態の3点セットがあれば、gradioでUIを作れることを示しました。

これを踏まえて本章では、自作モデルを使った推論を行うUIの作り方を示します。

数値を直接入力させてみる

まずirisデータの分類をKNNで学習した自作モデルと、それを使って説明変数から目的変数を推論する関数を作ります。

import gradio as gr
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

# Load iris data
X, y = load_iris(return_X_y=True, as_frame=True)

print(X.columns)


# Train a KNN classifier
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X, y)

def predict(sepal_length, sepal_width,
            petal_length, petal_width):
    mat = np.array([sepal_length, sepal_width,
                    petal_length, petal_width])
    mat = mat.reshape(1, mat.shape[0])
    df = pd.DataFrame(mat, columns=['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'])
    res = knn.predict(df)
    return res[0]

推論用関数と入出力の定義を渡してInterfaceのインスタンスを作ります。

# Define the interface
output_placeholder = gr.outputs.Label()
interface = gr.Interface(predict, ['number', 'number', 'number', 'number'],
                         output_placeholder)

最後にUIを起動します。

# Launch the interface
interface.launch()

起動した画面は下記のようになっているはずです。

sample_4.png

入力をスライダーバーに差し替える

先の例では、説明変数に相当する数値を直接書き込んで入力する必要がありました。

これは少し面倒ですね。そこで以下2点の改修を行おうと思います。

  1. 入力形態をスライダーバーに差し替える
  2. 各説明変数のデフォルト値を与える

そのためにInterfaceのコンストラクタ引数を下記のように差し替えます。(差し替え場所がわかりやすいよう、差し替え前のコードをコメントアウトで残しています)

# Define the interface
sepal_length = gr.inputs.Slider(
    minimum=1, maximum=10, default=X['sepal length (cm)'].mean(), label='sepal_length')
sepal_width = gr.inputs.Slider(
    minimum=1, maximum=10, default=X['sepal width (cm)'].mean(), label='sepal_width')
petal_length = gr.inputs.Slider(
    minimum=1, maximum=10, default=X['petal length (cm)'].mean(), label='petal_length')
petal_width = gr.inputs.Slider(
    minimum=1, maximum=10, default=X['petal width (cm)'].mean(), label='petal_width')
output_placeholder = gr.outputs.Label()
interface = gr.Interface(predict,
                         # ['number', 'number', 'number', 'number'],
                         [sepal_length, sepal_width,
                          petal_length, petal_width],
                         output_placeholder)

このUIを起動すると下記のようになります。

sample_5.png

ラベルごとの確率を出力する

outputの欄にラベルと確率の組を列挙したい場合は、先に定義したpredict関数の内容を下記に差し替えます。

def predict(sepal_length, sepal_width,
            petal_length, petal_width):
    mat = np.array([sepal_length, sepal_width,
                    petal_length, petal_width])
    mat = mat.reshape(1, mat.shape[0])
    df = pd.DataFrame(mat, columns=['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'])
    # res = knn.predict(df)
    # return res[0]
    res = knn.predict_proba(df)
    res_dict = {}
    for i in range(len(res[0])):
        # NOTE: クラスと確率の組をdictで返すとラベルに両方とも綺麗に表示される。
        # そのときkeyにするクラスの型は文字列にするのが望ましい。
        res_dict[str(i)] = res[0][i]
    return res_dict

そうすると、outputの出力が下図右側のように差し変わります。

sashie.png

フラグ機能とは

これまで見てきたUIには、謎の「フラグする」というボタンがありました。この使い方について説明したいと思います。

gradioのフラグとは、推論結果を保存する機能のことです。
あとで振り返りたいと思った推論結果や、入力例として再利用したい推論結果を保存できます。

先に述べた「フラグする」ボタンを押せば、推論結果を保存できます。
保存先はデフォルトではPythonの実行フォルダの直下のflaggedです。この下には入出力データの即値やファイルパスをまとめたlog.csvが生成されます。log.csvにパスが書かれるファイルは例えば入力画像のファイルであったり、多値分類の結果をラベルと確率を対にして保存したJSONファイルです。

flagged/ディレクトリにある推論結果を入力例として再利用したい場合、Interfaceのコンストラクタ引数exampleにディレクトリパスを指定する必要があります。

interface = gr.Interface(predict,
                         [sepal_length, sepal_width,
                          petal_length, petal_width],
                         output_placeholder,
                         examples='flagged/')

そうすると下記のように入力例がUI下部に表示されます。テーブルデータの場合はそのまま表形式として出力されます。入力したい値のある行をクリックすればその値をモデルに与えられます。

sample_6.png

フラグ機能を使いたくない場合は、下記のようにInterfaceのコンストラクタでallow_flagging='never'を指定します。UIの具体例は割愛しますが、この指定によってこれまで出力していた「フラグする」ボタンを非表示にできます。

interface = gr.Interface(predict,
                         [sepal_length, sepal_width,
                          petal_length, petal_width],
                         output_placeholder,
                         allow_flagging='never')

より本格的なWeb UIを作るには

これまではUIから得た入力をそのまま関数に引き渡して、その戻り値をUIに表示するという単純なWeb UIを作ってきました。
しかし、より本格的なWeb UIを作りたい人は例えば以下のような疑問を持つのではないかと思います。

  • 一つの画面に複数のボタンを置いて、それぞれに異なるイベントハンドラを定義できないの?
  • セッションごとに何らかの値を引き継ぐことができるの?
  • sklearn以外にも、PyTorchのモデルなんかもGradioから呼び出せるよね?

そういう疑問を解消するには、公式のQuickstartが非常に役立ちます。

複数のイベントハンドラの定義

https://gradio.app/blocks_and_event_listeners/#multiple-data-flowsのコード例が典型だと思います。とてもわかりやすい。

JavaScriptで簡単なイベントハンドラを記述することもできます。

セッションの定義

セッションは入出力のstateデータで保持することができます。
stateには(おそらく)任意の型の値を設定することができます。

https://gradio.app/interface_state/#session-stateには、チャットログをリストで保持するコード例が書かれています。こちらもとてもわかり易いです。

PyTorchのモデルとの連携

これまで説明したことからすると、推論を行う関数を用意できるなら任意のモデルと連携するロジックが組めそうです。ただPyTorchのモデルで扱える型のデータを入力として与えられるかが心配ですね。例えば画像を入力とするモデルの場合どうすればよいのでしょう?

ですがその心配はいりません。https://gradio.app/image_classification_in_pytorch/#step-3-creating-a-gradio-interfaceにある通り、画像データをPillowのImage型に変換したものをGradioから推論用関数に渡すことができるからです。Image型ならPyTorchのTensorに変換できますので、画像を問題なくPyTorchのモデルに渡せます。

同様にTFやKerasのモデルにも画像データを渡すことができます。https://gradio.app/image_classification_in_tensorflow/に実装例があります。

ウィジェット一覧

公式のAPI仕様書の"Component"の節に、入出力に使えるウィジェットが列挙されています。動くサンプルもあるので、おすすめです。

その他の機能

使用頻度は高くないかもしれませんが、「こんなこともできる!」という機能例をいくつか紹介します。

地図アプリとの連携

Plotlyの地図描画機能をGradioで使う例も紹介されています。

コードを見る限りでは、地図に限らずPlotlyのグラフ全般を同じ仕組みで描画できそうです。
実際にHuggingFaceのWebサイトには、GradioでPlotlyの散布図を描画する例が示されています。

見た目のカスタマイズ

細かいカスタマイズはできませんが、一部のウィジェットの見た目をCSSで変えることはできます。

終わりに

Gradioのウィジェットの見た目をそのまま流用してよければ、Gradioによってかなり自由度の高いWeb UIを作ることができます。
例えばStable Diffusionの代表的なGUI環境であるAUTOMATIC1111/stable-diffusion-webuiもGradioで作られており、かなり複雑なUIも実現可能であることがわかります。

精査はしていませんが、HuggingFaceなどのWebサイトでも組み込まれて活用されていることからセキュリティもある程度担保されていると予想されます。
ライセンスもApache License 2.0ですので、当該ライセンスに基づくコードを組み込んでいるとの文言を添付すれば商用でも利用可能です。

そのため例えば、お客さまに触っていただくプロトタイプをAIベンダが作るのにもGradioが活用できる余地は十二分にあると思います。

新規登録して、もっと便利にQiitaを使ってみよう

  1. あなたにマッチした記事をお届けします
  2. 便利な情報をあとで効率的に読み返せます
ログインすると使える機能について

コメント

この記事にコメントはありません。
あなたもコメントしてみませんか :)
新規登録
すでにアカウントを持っている方はログイン
6