見出し画像

Google Colab で Janus-Pro を試す

「Google Colab」で「Janus-Pro」を試したのでまとめました。

【注意】「Google Colab Pro/Pro+」のA100で動作確認しています。

1. Janus

1-1. Janus-Pro

Janus-Pro」は、前作「Janus」の進化版です。具体的には、「Janus-Pro」には (1) 最適化された学習戦略、(2) 拡張された学習データ、(3) より大きなモデルサイズへのスケーリングが組み込まれています。これらの改良により、「Janus-Pro」は、「マルチモーダル理解」と「Text-to-Image」への指示追従機能の両方で大幅な進歩を達成し、「Text-toto-Image」の生成の安定性も向上しています。

画像

1-2. Janus

Janus」は、マルチモーダル理解と生成を統合する新しい自己回帰フレームワークです。処理には単一の統合されたTransformerアーキテクチャを使用しながら、視覚エンコーディングを別々の経路に分離することで、従来のアプローチの限界に対処します。分離により、理解と生成における視覚エンコーダーの役割の競合が軽減されるだけでなく、フレームワークの柔軟性も向上します。「Janus」は、以前の統合モデルを凌駕し、タスク固有のモデルの性能に匹敵するか、それを上回ります。「Janus」のシンプルさ、高い柔軟性、有効性により、次世代の統合マルチモーダルモデルの有力な候補となっています。

画像

1-3. JanusFlow

JanusFlowは、自己回帰言語モデルと、生成モデリングの最先端の手法である修正フローを統合した最小限のアーキテクチャを導入します。主要な発見は、修正フローが大規模な言語モデルフレームワーク内で簡単に学習でき、複雑なアーキテクチャの変更が不要であることを示しています。広範な実験により、「JanusFlow」は、それぞれのドメインの専門モデルと同等またはそれ以上の性能を実現し、標準ベンチマーク全体で既存の統合アプローチを大幅に上回るパフォーマンスを発揮することが示されています。この研究は、より効率的で多用途な視覚言語モデルへの一歩を示しています。

画像

2. モデル

「Janus」のモデルは、次の4つが提供されています。

deepseek-ai/Janus-Pro-7B
deepseek-ai/Janus-Pro-1B
deepseek-ai/JanusFlow-1.3B
deepseek-ai/Janus-1.3B

3. 画像生成

「Google Colab」での画像生成の実行手順は、次のとおりです。

(1) パッケージのインストール。

# パッケージのインストール
!git clone https://github.com/deepseek-ai/Janus
%cd Janus
!pip install -e .

(2) 画像生成。
今回は、「Janus-Pro-7B」を使います。

import os
import PIL.Image
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor


# specify the path to the model
model_path = "deepseek-ai/Janus-Pro-7B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

conversation = [
    {
        "role": "<|User|>",
        "content": "A stunning princess from japan in red, white traditional clothing, black eyes, black hair of japanese anime style",
    },
    {"role": "<|Assistant|>", "content": ""},
]

sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation,
    sft_format=vl_chat_processor.sft_format,
    system_prompt="",
)
prompt = sft_format + vl_chat_processor.image_start_tag


@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float = 1,
    parallel_size: int = 16,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
):
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

    tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda()
    for i in range(parallel_size*2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

    for i in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
        hidden_states = outputs.last_hidden_state
        
        logits = mmgpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        
        logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)

        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)


    dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    os.makedirs('generated_samples', exist_ok=True)
    for i in range(parallel_size):
        save_path = os.path.join('generated_samples', "img_{}.jpg".format(i))
        PIL.Image.fromarray(visual_img[i]).save(save_path)


generate(
    vl_gpt,
    vl_chat_processor,
    prompt,
)

A stunning princess from japan in red, white traditional clothing, black eyes, black hair of japanese anime style

【翻訳】
日本のアニメ風の赤と白の伝統衣装、黒い目、黒い髪をまとった、日本から来た見事なお姫様

「Janus/generate_samples」の下に16枚の画像が生成されています。

画像

消費メモリは次のとおりです。

画像



いいなと思ったら応援しよう!

ピックアップされています

自然言語処理入門

  • 931本

コメント

ログイン または 会員登録 するとコメントできます。
あなたも書ける! note、はじめよう
プログラマー。iPhone / Android / Unity / ROS / AI / AR / VR / RasPi / ロボット / ガジェット。年2冊ペースで技術書を執筆。アニソン / カラオケ / ギター / 猫 twitter : @npaka123
Google Colab で Janus-Pro を試す|npaka
word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word word

mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1
mmMwWLliI0fiflO&1