「DPOを使用してLlama 2を微調整する」

Using DPO to fine-tune Llama 2

はじめに

人間のフィードバックからの強化学習(RLHF)は、GPT-4やクロードなどのLLMの最終トレーニングステップとして採用され、言語モデルの出力が人間の期待に合致するようにするために使われます。しかし、これによってRLの複雑さがNLPにもたらされます。良い報酬関数を構築し、モデルに状態の価値を推定するように訓練し、同時に元のモデルからあまり逸脱せずに意味のあるテキストを生成するように注意する必要があります。このようなプロセスは非常に複雑で、正しく行うのは常に簡単ではありません。

最近の論文「Direct Preference Optimization」(Rafailov、Sharma、Mitchell他)では、既存の方法で使用されるRLベースの目的を、シンプルなバイナリクロスエントロピー損失を直接最適化できる目的に変換することを提案しており、これによりLLMの改善プロセスが大幅に簡素化されます。

このブログ記事では、TRLライブラリで利用可能なDirect Preference Optimization(DPO)メソッドを紹介し、さまざまなスタックエクスチェンジポータルの質問に対するランク付けされた回答を含むスタックエクスチェンジのデータセットで最近のLlama v2 7Bパラメータモデルを微調整する方法を示します。

DPO vs PPO

人間の派生した好みをRLを通じて最適化する従来のモデルでは、補助的な報酬モデルを使用し、興味のあるモデルを微調整して、この報酬をRLの仕組みを利用して最大化するようにします。直感的には、報酬モデルを使用して最適化するモデルにフィードバックを提供し、ハイリワードのサンプルをより頻繁に生成し、ローリワードのサンプルをより少なく生成するようにします。同時に、フリーズされた参照モデルを使用して、生成物があまり逸脱せずに生成の多様性を維持し続けるようにします。これは通常、参照モデルを介した全報酬最大化の目的にKLペナルティを追加することで行われ、モデルが報酬モデルをごまかしたり利用したりすることを防ぐ役割を果たします。

DPOの定式化は、報酬モデリングのステップをバイパスし、報酬関数から最適なRLポリシーへの解析的なマッピングを使用して、言語モデルを好みのデータに最適化します。このマッピングは、与えられた報酬関数が与えられた好みのデータとどれだけ合致するかを直感的に測定します。したがって、DPOはRLHFの損失の最適解から始まり、変数の変換を介して参照モデルに対する損失を導出することで、参照モデルのみに対する損失を得ることができます。

したがって、この直接的な尤度目的は、報酬モデルやポテンシャルに煩雑なRLベースの最適化を必要とせずに最適化することができます。

TRLのトレーニング方法

前述のように、通常、RLHFパイプラインは次の異なるパーツで構成されています:

  1. 教師あり微調整(SFT)ステップ
  2. データに好みのラベルを付けるプロセス
  3. 好みのデータで報酬モデルをトレーニングする
  4. そして、RL最適化ステップ

TRLライブラリには、これらのパーツのためのヘルパーが付属していますが、DPOトレーニングでは報酬モデリングとRL(ステップ3と4)のタスクは必要ありません。代わりに、TRLのDPOTrainerにステップ2の好みのデータを提供する必要があります。このデータは非常に特定の形式を持ちます。具体的には、次の3つのキーを持つ辞書です:

  • prompt:テキスト生成の際にモデルに与えられるコンテキストプロンプトです
  • chosen:対応するプロンプトに対して選ばれた生成された応答を含みます
  • rejected:対応するプロンプトに対して望ましくない生成された応答を含みます

例えば、スタックエクスチェンジの好みのペアデータセットでは、データセットのエントリを次のヘルパー関数を使用して目的の辞書にマッピングし、元の列はすべて削除します:

def return_prompt_and_responses(samples) -> Dict[str, str, str]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"],   # rated better than k
        "rejected": samples["response_k"], # rated worse than j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)

データセットがソートされたら、DPO損失は基本的には教師あり損失であり、参照モデルを介して暗黙の報酬を得ます。そのため、DPOTrainerは、最適化したいベースモデルと参照モデルを必要とします:

dpo_trainer = DPOTrainer(
    model,                 # SFTパイプラインからのベースモデル
    model_ref,             # 通常はSFTでトレーニングされたベースモデルのコピー
    beta=0.1,              # DPOの温度ハイパーパラメータ
    train_dataset=dataset, # 上記で準備したデータセット
    tokenizer=tokenizer,   # トークナイザー
    args=training_args,    # トレーニング引数、バッチサイズ、学習率など
)

betaハイパーパラメータはDPO損失の温度パラメータであり、通常は0.1から0.5の範囲です。これにより、参照モデルにどれだけ注意を払うかを制御します。 betaが小さくなるほど、参照モデルを無視するようになります。トレーナーを初期化したら、指定されたtraining_argsでデータセットを使用してトレーニングするため、単に次のように呼び出すことができます:

dpo_trainer.train()

Llama v2で実験する

TRLでDPOトレーナーを実装する利点は、TRLとその依存ライブラリであるPeftおよびAccelerateが付属する大規模LLMのトレーニングに関連するすべての追加の機能を利用できることです。これらのライブラリを使用すると、bitsandbytesライブラリが提供するQLoRAテクニックを使用してLlama v2モデルをトレーニングすることさえできます。

教師ありファインチューニング

上記で紹介したプロセスでは、TRLのSFTTrainerを使用してデータのSFTスプリット上の7B Llama v2モデルでQLoRAを使用した教師ありファインチューニングステップが含まれています:

# ベースモデルを4ビット量子化で読み込む
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,        # "meta-llama/Llama-2-7b-hf"
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False

# 量子化されたベースモデルの上にLoRAレイヤーを追加する
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args,         # HF Trainerの引数
)
trainer.train()

DPOトレーニング

SFTが完了したら、結果のモデルを保存してDPOトレーニングに移ることができます。通常は、DPOのベースモデルと参照モデルの両方に前のSFTステップから保存されたモデルを使用します。次に、上記のstack-exchangeの好みデータでDPO目的関数を使用してモデルをトレーニングできます。LoRaアダプターを介してモデルがトレーニングされたため、PeftのAutoPeftModelForCausalLMヘルパーを使用してモデルを読み込みます:

model = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # 保存されたSFTモデルの場所
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,  # メインモデルと同じモデル
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

上記のように、モデルを4ビット構成でロードし、QLoraメソッドを使用してpeft_config引数を介してトレーニングします。トレーナーはまた、評価データセットに対するトレーニングの進捗状況を評価し、例えばWandBを介して記録および表示できる暗黙的な報酬などのいくつかの主要なメトリクスを報告します。その後、最終的にトレーニングされたモデルをHuggingFace Hubにプッシュできます。

結論

SFTとDPOのトレーニングスクリプトの完全なソースコードは、以下のexamples/stack_llama_2ディレクトリで利用可能であり、マージされたアダプタを備えたトレーニング済みモデルはこちらのHF Hubで見つけることができます。

DPOトレーニングランのWandBログは、こちらで見つけることができます。トレーニングと評価中に、DPOTrainerは以下の報酬メトリクスを記録します:

  • rewards/chosen:選択された応答のポリシーモデルと参照モデルの対数確率の平均の差をbetaでスケーリングしたもの
  • rewards/rejected:拒否された応答のポリシーモデルと参照モデルの対数確率の平均の差をbetaでスケーリングしたもの
  • rewards/accuracies:選択された報酬が対応する拒否された報酬よりも大きい頻度の平均
  • rewards/margins:選択された報酬と対応する拒否された報酬の平均の差

直感的には、トレーニング中にはマージンを増加させ、正確性を1.0にすることを望みます。つまり、選択された報酬が拒否された報酬よりも高くなる(またはマージンがゼロよりも大きくなる)ことです。これらのメトリクスは、評価データセット上で計算することができます。

私たちは、この方法を使って大規模な言語モデルを自分のデータセットに合わせる試みにおいて、コードの公開により読者の皆様の参入障壁が低くなることを望んでいます。そして、皆様が何を構築するかを楽しみにしています!また、モデルを自分で試してみたい場合は、こちらからできます:trl-lib/stack-llama。

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

人工知能

ジョシュ・フィースト、CogitoのCEO兼共同創業者 - インタビューシリーズ

ジョシュ・フィーストは、CogitoのCEO兼共同創業者であり、感情と会話AIを組み合わせた革新的なプラットフォームを提供するエ...

人工知能

「アナコンダのCEO兼共同創業者、ピーターウォングによるインタビューシリーズ」

ピーター・ワンはAnacondaのCEO兼共同創設者ですAnaconda(以前はContinuum Analyticsとして知られる)を設立する前は、ピー...

人工知能

ベイリー・カクスマー、ウォータールー大学の博士課程候補 - インタビューシリーズ

カツマー・ベイリーは、ウォータールー大学のコンピュータ科学学部の博士課程の候補者であり、アルバータ大学の新入教員です...

データサイエンス

「3つの質問:ロボットの認識とマッピングの研磨」

MIT LIDSのLuca CarloneさんとJonathan Howさんは、将来のロボットが環境をどのように知覚し、相互作用するかについて議論し...

データサイエンス

「Adam Ross Nelsonによる自信のあるデータサイエンスについて」

データサイエンスの中で新たな分野が現れ、研究内容が理解しにくい場合は、専門家や先駆者と話すのが最善です最近、私たちは...

人工知能

「ジャスティン・マクギル、Content at Scaleの創設者兼CEO - インタビューシリーズ」

ジャスティンは2008年以来、起業家、イノベーター、マーケターとして活動しています彼は15年以上にわたりSEOマーケティングを...