17

この記事は最終更新日から1年以上が経過しています。

Organization

Huggingface Transformersのモデルをオフラインで利用する

Huggingface Transformersは、例えばGPTのrinnaのモデルなどを指定することでインターネットからモデルをダウンロードして利用できます。

ある時HuggingfaceのWebサイトが落ちていて、Transformersを利用したプログラムが動かなくなった事がありました。しかし通常の動作ではモデルのデータはキャッシュされているはずで、一度ダウンロードされていたらオフラインでも動作するのではないか?と思いました。

本稿では、Transformersのコードが一時的にインターネットに接続できないオフライン状態になっても動作できる方法を記載します。

環境

  • Windows 11 WSL2 Debian Bullseye
  • Docker Python:3.9-slim
  • Transformers 4.16.2

実装

本稿ではGPTのrinnaのモデルで文章を生成する下記のプログラムを使います。ソースコードもほとんどサンプルコードそのままです。

import torch
from transformers import T5Tokenizer, AutoModelForCausalLM

class Generator():
    def __init__(self, model_name = "rinna/japanese-gpt-1b"):
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def gen(self, text: str, max_length=100):
        token_ids = self.tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")

        with torch.no_grad():
            output_ids = self.model.generate(
                token_ids.to(self.model.device),
                max_length=max_length,
                do_sample=True,
                top_k=500,
                top_p=0.95,
                pad_token_id=self.tokenizer.pad_token_id,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                bad_word_ids=[[self.tokenizer.unk_token_id]]
            )
        output = self.tokenizer.decode(output_ids.tolist()[0])
        return output

とりあえず適当な文章を生成してみます。

generator = Generator(model_name="rinna/japanese-gpt2-xsmall")
generator.gen("吾輩は猫である。名前は")

吾輩は猫である。名前は家から出てたのだからどこかもかわいそうだ。そして、この子は生まれも育ちも真昼のところだったというが、本当のところはどうでもいい。その子ネコ猫の妖精はまさしくその妖精か、この妖精のようなタイプである。彼らについて真昼の寒さに耐えられるなあ、と思って、暗い影からずぼらして見たが、森の中どもで普通に歩いている人間を見た。そして、

うまく生成されました。

この時、初回実行時はfrom_pretrainedメソッド実行時にダウンロード処理が走ります。Jupyterで実行していたら下記のようなプログレスバーが表示されるはずです。

image.png

このダウンロードされたモデルは通常キャッシュされており、2回目以降はダウンロード処理をせずに高速で動作します。

ただし、ここでPCをインターネットから切断するなどしてダウンロードができない状態にして同じプログラムを実行すると、下記のようなエラーが発生します。

ValueError: Connection error, and we cannot find the requested files in the cached path. Please try again or make sure your Internet connection is on.

Huggingfaceのモデルデータのダウンロード

huggingface_hubsnapshot_download()を使って、一度モデルを明示的にダウンロードしてそのパスを指定することで、ローカルファイルとして実行できます。

# インターネットでダウンロードできる時に実行
from huggingface_hub import snapshot_download
download_path = snapshot_download(repo_id="rinna/japanese-gpt2-xsmall")
# オフラインで実行
generator = Generator(model_name=download_path)
generator.gen("吾輩は猫である。名前は")

このプログラムはインターネットから切断した状態でも動作します。

Transformersのオフラインモード

インターネットの接続有無で挙動が変化することが無いよう、Transformersをそもそもインターネットに接続せずに動作するようにできます。

環境変数TRANSFORMERS_OFFLINE1を設定するとTransformersはオフラインモードとなり、インターネットを通してのモデルのダウンロードを行わなくなります。この状態で、インターネットからのダウンロードが必要なモデルを指定するとエラーになります。

from os import environ
assert environ["TRANSFORMERS_OFFLINE"] == "1"
generator = Generator(model_name="rinna/japanese-gpt2-xsmall")
generator.gen("吾輩は猫である。名前は")

OSError: Can't load tokenizer for 'rinna/japanese-gpt2-xsmall'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'rinna/japanese-gpt2-xsmall' is the correct path to a directory containing all relevant tokenizer files.

なお先述のhuggingface_hub.snapshot_download()TRANSFORMERS_OFFLINE1でも利用できます。

ダウンロードできないときの挙動

キャッシュされているはずなのにダウンロードできない時エラーが出る理由ですが、キャッシュが存在する時もETagを確認しにHTTPリクエストを投げています。このETagをキャッシュのファイル名に使っており、HTTPリクエストが失敗した場合にETagが空文字列としてファイルを検索されるためキャッシュがヒットしないということになっていました。

(この時HTTPリクエストエラーを一旦例外とせずpassしてEtagを空文字にしてプログラムを進めているのですが、その理由はよく分かりませんでした)

参考文献

新規登録して、もっと便利にQiitaを使ってみよう

  1. あなたにマッチした記事をお届けします
  2. 便利な情報をあとで効率的に読み返せます
ログインすると使える機能について
17