交替直接差分学習法ADDifT(Alternating Direct Difference Training)の解説

 LoRA学習の技術は日進月歩ですが、その根本となる学習手法は初期の段階からほとんど変わっていません。今回は新しい学習手法を開発したので解説します。この技術ではコピー機学習で行っていた差分学習の所要時間を10~20分の1に短縮し、複数の画像セットの同時学習も可能とします。画像モデルだけでなく他の拡散モデル全般に応用できる可能性があります。

LoRA学習とコピー機学習法 

 まずはふつうの学習手法についておさらいしてみましょう。

画像
伝統的な学習方法

 ふつうのLoRA(モデルも同じ)学習過程では、ノイズを加えられた画像 ( x_t ) を U-Net (最近のモデルは DiT だったりしますが) に通して出てきた予測ノイズ ( \hat{\epsilon}_\theta(x_t, t) ) と、本来のノイズ ( \epsilon ) を比較します。LをLossとすると、

L = \mathbb{E}{x, \epsilon, t} \big[ | \epsilon - \hat{\epsilon}\theta(x_t, t) |^2 \big]

となりますが、ノイズを扱っているので一見分かりにくいですね。これは拡散モデルの仕組み上、ノイズを徐々に減らしていくことで画像を再構築するためです。ここでは予測ノイズ ( \hat{\epsilon}_\theta(x_t, t) ) が元のノイズ ( \epsilon ) と一致するように学習が行われます。わかりやすく言い換えると、Unetを通して出てきた画像が入力画像とどれだけ似ているかを比較しているということと同義です。学習しているのだから当たり前ですね。

 ここで問題になるのは学習するときに、キャラクターの特徴だけを学習してもらいたいのに画像に付随する要素まで学習されてしまうということです。画像ではピースサインをしていますが、キャラクター=ピースサインをしている人物という所まで学習されてしまうわけですね。

 あるいは背景まで学習されてしまう、画風まで学習されてしまうと言うことが起こりえます。タグ付けや多くの画像を学習させることによってある程度は軽減できますが、逆に学習してもらいたい要素が抜け落ちたりしてしまうわけです。

 さて、このような問題を回避する方法としてコピー機学習法があります。月須和・那々氏が考案した学習方法で、ふたつのLoRAを学習するところが特徴です。コピー機学習法ではLoRAを学習する時に一枚だけの画像を極限まで学習させ、どのような初期ノイズにおいても同じ画像しか出さないモデルを作ります。その後、画像を少し改変した画像をつくり、別なLoRAを学習すると、同じ画像しか出さないモデルが少し改変した画像を学習することになるので、画像間の変化のみを学習できるという方法です。詳細はエマノンさんがわかりやすい記事を書いているのでそちらを参照してください。

 コピー機学習法は大変強力な手法で、これまで実現できなかった一部分のみを改変するLoRAを作ることができるようになりました。

 新しい手法は学習することはコピー機学習法と同じですが、学習方法が異なります。この手法では2枚1組の画像( x_1 ), ( x_2 )を使います。交替直接差分学習法・ADDifT(Alternating Direct Difference Training)の「直接」は差分を直接LoRAに学習させると言うことを意味します。

 目を閉じた・開けた差分を学習する事を考えましょう。ADDifTでは学習対象の LoRA のパラメータ更新は以下の損失関数に従います。\theta^+(x_2, t)はLoRAを適用した状態で目を閉じた画像を推論した結果。\theta(x_1, t)は目を開けた画像を推論した結果。

L_{\text{diff}} = \mathbb{E}{x_1, x_2, \epsilon, t} \big[ | (\hat{\epsilon}\theta^+(x_2, t) - \hat{\epsilon}\theta(x_1, t) |^2 \big]

 つまり、従来の学習が「ある画像そのものを生成する」ことを目的にしているのに対し、ADDifT は「ある画像の変化部分のみを学習する」ことになります。

画像
ADDifT

 この手法ではふたつの画像の差分そのものをLoRAに学習させます。つまり、いったんコピー機を作る手間がなくなるわけですね。また、差分を直接学習するので学習に必要なステップも30~100と少なくて済みます。というか100 step以上学習すると過学習になりよろしくありません。

交替学習(Alternating Training)

 では「交替」は何かというと、攻守を交替しながら学習を行うことを意味します。直接差分を学習させるだけではうまくいかないので導入しました。差分なので理想的には目の部分だけを学習してほしいところですが、ノイズ予測の段階で差異が出てしまい、目以外の部分も学習されてしまうようです。
 その効果をキャンセルするために逆の課程の学習も行っているのです。つまり、「目を開ける→目を閉じる」の方向と、「目を閉じる→目を開ける」の学習を交互に行っているのです。ここで、「目を閉じる→目を開ける」の学習を行うときにはLoRAの適応率をマイナスにしています。これで目の部分以外をキャンセルしつつ、差分だけを学習できるようになるのです。

\begin{array}{} L_{\text{diff}+} &=& \mathbb{E}{x_1, x_2, \epsilon, t} \big[ | (\hat{\epsilon}\theta^+(x_2, t) - \hat{\epsilon}\theta(x_1, t) |^2 \big] \\ L_{\text{diff}-} &=& \mathbb{E}{x_1, x_2, \epsilon, t} \big[ | (\hat{\epsilon}\theta^-(x_1, t) - \hat{\epsilon}\theta(x_2, t) |^2 \big] \end{array}

Scheduled Random Timesteps

 新しい概念としてScheduled Random Timestepsというのも導入しています。この学習では学習ステップが30~100と極端に少なくなります。これまでと同じ方法で学習を行うと起きてしまうのがTimestepsの偏りです。学習時、Timestepsは0~1000までの値がランダムに決められます。このTimestepsは均等に学習されることが理想ですが、実際には通常の学習回数の条件下ではそうなりません。総学習回数(step x batch size)が1000の場合にはそれなりに均等に配分されますが、30~100の場合には偏りがでてしまいます。つまり、同じ画像、同じ設定であるのに学習結果が全く異なるという事が起きえるわけです。そこで、Timestepsの選択をある程度絞るのがScheduled Random Timestepsです。この方法では0~1000のTimestepsを5分割し順番にその中からランダムに選ぶという処理になっています。つまり、
0~200の中から選ぶ、201~400の中から選ぶ、・・・、800~1000の中から選ぶということをやっています。

 実際にはさらにTimestepsを絞っていて500~1000を5分割しています。これは後述しますが目を閉じるなどのLoRAにおいては高Tiemstepsで学習することで学習が安定するからです。一方で画風LoRAになると低Timestepsで安定します。

 実際のLossの推移は以下のようになります。高TimestepsほどLossが小さくなる訳ですが、これを平らにした方がいいような気もするし、やめておいた方がいいような気もする所です。とりあえず何もしなくてもうまく動いているので今のところはこのままです。

画像
Loss

 これにより、従来(500+500 steps) × バッチサイズ 2 = 2000 かかっていた学習を30 steps × バッチサイズ 1 = 30まで減らすことに成功しています。SD1.5では目を閉じるLoRAの作成にかかった時間は30秒です(3060 12GBを使用)。でもそこまで早くて性能が悪くては意味がありません。

画像
比較。上段がLoRA無し、中段がコピー機学習、下段がADDifT

 これはコピー機学習で作成したLoRAとの比較ですが、遜色ないものが出来ていることが見て取れます。

 次に画風LoRAも試してみましょう。画風LoRAでは学習対象のTimestepsを200-400に設定しています。入力画像はHires-Fix前後の画像を学習対象とします。512x512の画像を3倍でHires-Fixし、512x512に縮小して比較します。これで描き込みが多くなる変化を学習できることが期待されるわけですね。

画像
入力画像。Hires-Fixの前後を使用
画像
ディテールを調整するLoRA

 ちゃんと出来ていますね。目を閉じるなどのLoRAなら簡単なのですが、画風LoRAの場合には学習対象の画像の用意が難しく、試行錯誤しなければいけないわけですが、コピー機学習の場合にはできあがるまでに10~20分かかってしまってなかなか試行錯誤できなかった訳ですが、1分程度でできるなら試行回数を稼げますね。

 さて、ADDifTには早くなる以外にもメリットがあります。それは複数画像の学習が行えるということです。コピー機学習では一セットだけしか学習できませんでした。しかし、それでは汎用性に問題が出てしまいます。例えば目を閉じる/開ける場合には、正面の画像だけの学習では横を向いた時や顔を傾げた時などには適用度合いが下がってしまうわけです。これに対応するには複数セットで学習を行い後にマージするなどが考えられますが手間がかかります。ADDifTでは画像セットを変えながら学習を行うことができるので、より汎用性の高い学習が高速に行えるわけですね。

Timestepsの設定

 ADDifTの学習においてはTiemstepsを限定することが大切になります。
通常、学習においてはTimestepsはまんべんなく学習することが良しとされていますが、ADDifTにおいては学習する対象によって限定した方が学習が安定します。目を閉じる/開けるのような動作の場合にはTiemstepsを500~1000にします。逆にディテールを追加するような画風LoRAの場合にはTiemstepsを200~400にし、学習対象と元画像を逆にすると順方向に学習が行われます。500~1000を入れると学習が正常に進みません。なぜかはわからないのですが、Timestepsが小さい領域だと学習が逆方向に進んでしまうようなのです。

 例ではSD1.5を出していますが、XLでも動作する事を確認しています。試しに作ったディテールを追加するLoRAをCivitAIに上げておきました。FluxやSD3.5などのDiTを使った新しいモデルに対して有効化はまだ試してません。というか3060で動くのか怪しいので5000シリーズの供給が落ち着いてきたら試してみるつもりです。

今後の発展

ノイズ予測同士を比較するという方法はiLECOでも使っているわけですが、これは通常のLoRA学習にも応用できるのではないかと考えています。自己正則化に使えるのではないかと。DreamBooth方式では正則化と呼ばれるLoRAの過学習を抑える仕組みが提案されています。これは学習対象の画像とはことなる一般的な画像をまぜて学習することで過学習などを抑制するものです。
 例えばキャラクターを学習させるときに、学習対象のキャラクター名とキャラクターの性別や体格、髪型、画像の背景などをプロンプトとして入力します。しかし、本来キャラクター名に紐付けられるべき情報が他のプロンプトにも学習されてしまうことが起きるわけですね。性別を入力しただけでそのキャラクターが出てしまう事がおこるわけです。
 正則化ではそのキャラクター以外の画像を性別などのプロンプトを入力して学習することでキャラクター名以外の学習を薄めることを目的としているわけですが、画風などさらに余計なことまで学習してしまうなどの副作用も強く、いまではほとんど使われていません。
 そこでノイズ予測同士の比較が生きてきます。LoRAを適用してキャラクター名以外のプロンプトで生成した予測と、LoRAを適用せずに同じプロンプトで生成したノイズ予測を比較する事で、悪影響のない正則化が可能ではないかというわけです。
 このように、ノイズ予測同士比較はまだまだ様々な可能性を秘めているのではないかと考えています。Diffusionモデル全般に応用できるのではないかと。

まとめ

 ADDifTは画像生成モデルの新しい学習手法としてふたつの画像の差分を高速に学習する手段となります。これはDiffusionモデル全般に適用できる可能性があり、動画やその他のモデルにも使えるのではないかと考えています。


いいなと思ったら応援しよう!

コメント

ログイン または 会員登録 するとコメントできます。
交替直接差分学習法ADDifT(Alternating Direct Difference Training)の解説|hakomikan
word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word

mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1