見出し画像

AI を自分好みに調整できる、追加学習まとめ ( その5: LoRA)

 こんにちはこんばんは、teftef です。今回も追加学習手法についてです。これまで説明してきた Diffusion Model のファインチューニングでは一般的に Unet , Text Transformer の再学習を行いました。しかし、全てのパラメーターを再学習するには時間がかかってしまいます。今回はファインチューニング後のモデルの品質を下げず、省時間、省メモリの手法を実現した軽量化手法、『LoRA』 です。
 といっても今回は数学的な話はあまりせず、実際に 普通のファインチューニング手法と LoRA を用いた手法の違いを比べてみたいと思います。
 私もまだ初学者であり、説明が間違っていたり勘違いがある可能性が 0 ではないということをご了承ください。ぜひコメントなどをいただけたら幸いです。
それでは行きます。

使用した論文

問題提起

 機械学習(特に自然言語処理の分野)ではパラメーター数が大きく、ファインチューニングで全てのパラメーターを再学習するために多くの時間とコストを費やします。例えば GPT-2 では15億ほどのパラメーターがありそれらをすべて再学習させるのにはかなりの時間がかかります。さらに GPT-3 ではパラメーター数が1750億に上り、これを全て再学習させるのは現実的とは言えません。そこで多くの既存の手法として、パラメーターの一部のみを学習させることや、何か外部のモジュールによって学習(HyperNetworks のような)させるというものがありますが、これらは省時間、省コストをしている代わりに、モデルの品質低下が起こってしまいます。(多くの場合、学習効率と品質の間はトレードオフの関係である)。このような問題に対処するためにいくつか( 大きく2 つ 、① : Transformer に Adaptation 層を追加する、② : 入力のLayer activision の最適化(?))の解決策が出ており、今回はその中でも「Transformer に Adaptation 層を追加する」という手法について焦点を当てます。


「Transformer に Adaptation 層を追加する」


  今回の手法 LoRA (Low-Rank Adaptation) では Transformer の層ごとに学習可能なランク分解行列(パラメーター)を挿入します。この新しく追加したパラメーターを学習することによって学習するパラメーター数を1万分の1にすることができ、必要な VRAM を3分の1にできます。これによって大幅な省時間、省メモリが可能となます。さらにファインチューニングしたときのモデルの品質が LoRA を使わなかった場合に比べて大きく劣化しないというのが特徴です。

少し詳しく

 実際に 自然言語モデルをファインチューニングするとき、勾配法を用います。学習途中では以下の式を繰り返し使い、条件を達成するためのパラメーター Φ をいい感じに調整する(最適化)ことが目標です。最終的に初期モデルのパラメーター Φ_0 にパラメーター増分(差分) ΔΦ を追加し、Φ_0 + ΔΦ となります。

画像

 しかしこれでは、勾配法の各ステップにおいて、全てのパラメーターを調整する必要があるため、実現が困難となってしまいます。そのため、今回はパラメーター増分(差分) ΔΦ をより小さいサイズのパラメーターで符号化します。つまりモデルのパラメーターを調整するパラメーター(Θ)を用意します。パラメーター増分(差分) ΔΦ を見つけるタスクはパラメーター(Θ)を最適化するタスクとなります。式で表すと以下の通りです。

画像

 これによって、比較的少ないパラメーター(Θ)で学習をすることができるようになりました。もとのモデルの 1% 未満のパラメーター数になることもあり、大幅な時間短縮がされたように見えます。

レイテンシーの増加

 しかし、ただ「Transformer に Adaptation 層を追加する」だけでは推論時の入力から出力を出すまでの処理に時間がかかりすぎてしまうことがあります(レイテンシー)。というのも機械学習では大量ののパラメーターを扱うため GPU を使用した並列処理がなされています。これによって(全部 CPU で計算するより)比較的短時間で処理ができます。しかし Adaptation 層 は並列処理ではなく逐次処理のためいくら並列処理が早く終わってもこの Adaptation 層 の処理を待たなくてはいけないため、結果的にレイテンシーの増加につながってしまいます。

LoRA (Low-Rank Adaptation) 

それでは具体的に LoRA がやっていることを見ていきます。

画像

 学習済みモデルの重み W_0 ( d × k 行列)を固定し、学習されないようにします。続いてパラメーター増分(差分) ΔW を行列 BA に分解します。ここで B : d × r 行列、A : r × k 行列とします。もちろん r は d,k より小さい次元であり、これによってパラメーターを r 次元まで押し込んでいます。

画像

 LoRAではこの r 次元の部分を学習するため、パラメーター数が減り、省時間、省メモリを実現しています。また行列 BA を学習することで、全てのパラメーターが学習された時と同等の表現力を維持することができます。それに加え、今回は元の重み  W_0 にパラメーター増分(差分) ΔW = BA を追加しているます。元の重み W_0 は変更しないので考える必要がありません。そのため学習時はこの差分のみをロードすればよいので高速な処理をすることができます。
  これは推論時でもパラメーター増分(差分) ΔW = BA のみを使えばよいので、メモリへのアクセス時間 (メモリオーバーヘッド) が少なく済みます。 Adaptation 層 は並列処理ではなく逐次処理であるという問題は依然解決していませんが、少なくとも大幅な時間短縮となります。

結果

  • 元の重み W_0 は変更しないので (差分のみ学習するため)、省メモリを実現し、VRAM 使用量を 2/3 以上削減できました。GPT-3 モデルでは VRAM 消費量を 1.2 TB から 350 GB まで削減できたそうです。

  • 学習する重み行列(パラメーター) の rank (次元)を r 次元まで下げたことで学習する量が少なくなり、省時間、省メモリとなりました。GPT-3 モデルでは r =4 とすると重みを 10000 倍以上削減することができました。

  • 差分のみを学習しているので、元の学習済みモデルにこの差分を足すだけで、ファインチューニング済みモデルが使用可能となるため、複数の差分を切り替えることが簡単になっています。

実験

 それでは実際に LoRA を使用してファインチューニングしたときの結果と必要なマシンスペックを比較してみました。

データセット

 今回使用したデータセットはこのような感じの画像を 2200 枚
Prompt などはまとめてあるので「悪用しない限り」、ご自由に使用してください。

画像
画像
画像
画像

データセットの集め方はこちらから ↓

普通にファインチューニング

[条件]

・GPU : Tesla T4,  VRAM : 16GB
・学習済みモデル : 
Waifu Diffusion 1.3 full   wd-v1-3-full.ckpt (7.7 GB)
・追加学習用画像 : 2200 枚
・Step : 5000
・batch_size : 4
・learning_rate : 5e-6

 Google Colab のGPU使用上限に達したため 、5000 Step の学習となっています。かかった時間は batch_size 4 , 5000 Step で 2.5 時間ほど。以下にモデルを載せておきます。

LoRA を使ったファインチューニング

[条件]

・GPU : RTX2070 SUPER , VRAM : 8 GB
・学習済みモデル : 
Waifu Diffusion 1.4   wd-1-4-anime_e1.ckpt full (5.16 GB)
・追加学習用画像 : 2200 枚
・Step : 50000
・batch_size : 1
・learning_rate : 1e-4

 自分の GPU を使って学習を行いました。条件はそろえたかったのですが、 batch_size 1 , 50000 Step とかなりの量を学習してもかかった時間は 5 時間ほどでした。以下にモデルを載せておきます。

生成された画像

 左から、 Waifu Diffusion 1.4 , 普通にファインチューニングしたモデルLoRA を使ったファインチューニング
になっています。

・Seed は行ごとに固定
・Prompt : a girl,
・Negative Prompt : low quality,bad face,bad hands,
・Step : 28
・CFG Scale : 7.5
・Sampling method : Euler a
・w*h=512*768

画像
seed : 200
画像
seed : 300
画像
seed : 2735

少し Prompt を変更

画像
 Prompt : agirl, beautiful face, 
seed : 952
画像
Prompt :masterpiece, seed : 448248
画像
Prompt :masterpiece, seed : 64882

 普通にファインチューニングした場合 (真ん中) は 5000 Step しか学習していない+WD1.3 ベースなので顔が崩れたり、背景が崩れてしまっています。
 それに対して LoRA を使用したファインチューニングでは 50000 Step 学習したため、ベースの WD1.4 と比較しても大きく品質が劣ることはなく、学習させた「ゆめかわ」風の画像が出力されています。

結論

 今回はこのようにベースとなるモデルが違ったためこのように品質に差が出てしまいましたが、LoRA の強みは省 VRAM 、省時間で学習できることです。これにより自宅の PC でも GPU が RTX 2070 , 3060 , 3090 でも学習ができるようになります。また LoRA によって学習されても大きな品質劣化を起こさないというのが大きな強みとなっています。実装までに少し時間がかかるかと思いますが、kohya さんの Github に説明が詳しく載っているので、とても分かりやすかったです。LoRA を使った DreamBooth 手法もあるのでそちらも併せてお読みください。

LoRA  を "試す"

Google Colab での実装はこちらから

参考文献

次回予告と宣伝

 今回は LoRA を論文をもとににまとめ、実装してみた結果をまとめてみました。次回の内容はまだ決まってませんが、おそらく追加学習手法は今回が最後で、次回は別の話になるかと思います。
 今回作成したモデルはこちらに公開しています。SD2.0 ベースなので yaml ファイルのダウンロードも忘れずに。

最後に

 最後まで読んでいただきありがとうございました。最後に少し宣伝です。主のteftefが運営を行っているdiscordサーバーを載せます。このサーバーではMidjourneyやStble Diffusionのプロンプトを共有したり、研究したりしています。ぜひ参加して、お絵描きAIを探ってみてはいかがでしょう。
 質問、リクエスト、バグ報告もこちらで行っています。
(teftef)

いいなと思ったら応援しよう!

ピックアップされています

論文解説まとめ

  • 30本

コメント

ログイン または 会員登録 するとコメントできます。
お年玉ポイントキャンペーン noteで記事を買うと 抽選で最大全額戻ってくる 1/9(木)まで 条件・上限あり
北の方の大学の M2 になれました。趣味でアニメーションと写真レタッチなどをやっています。この記事では幅広いジャンルで備忘録のようなものを書いていくつもりです!
AI を自分好みに調整できる、追加学習まとめ ( その5: LoRA)|teftef
word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word

mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1