Sentence Transformersモデルのトレーニングと微調整

Sentence Transformersのモデルトレーニングと微調整

このNotebook Companion付きのチュートリアルをご覧ください:



センテンス変換モデルのトレーニングまたはファインチューニングは、利用可能なデータと目標のタスクに大きく依存します。キーは2つあります:

  1. モデルにデータを入力し、データセットを適切に準備する方法を理解する。
  2. データセットと関連する異なる損失関数を理解する。

このチュートリアルでは、以下の内容を学びます:

  1. “スクラッチ”から作成するか、Hugging Face Hubからファインチューニングすることにより、センテンス変換モデルの動作原理を理解する。
  2. データセットの異なる形式について学ぶ。
  3. データセットの形式に基づいて選択できる異なる損失関数について確認する。
  4. モデルのトレーニングまたはファインチューニング。
  5. Hugging Face Hubにモデルを共有する。
  6. センテンス変換モデルが最適な選択肢でない場合について学ぶ。

センテンス変換モデルの動作原理

センテンス変換モデルでは、可変長のテキスト(または画像ピクセル)を、その入力の意味を表す固定サイズの埋め込みにマップします。埋め込みの取得方法については、前回のチュートリアルをご覧ください。この投稿では、テキストに焦点を当てています。

センテンス変換モデルの動作原理は次の通りです:

  1. レイヤー1 – 入力テキストは、Hugging Face Hubから直接取得できる事前学習済みTransformerモデルを通過します。このチュートリアルでは、「distilroberta-base」モデルを使用します。Transformerの出力は、すべての入力トークンに対する文脈化された単語の埋め込みです。テキストの各トークンに対する埋め込みを想像してください。
  2. レイヤー2 – 埋め込みはプーリングレイヤーを通過して、テキスト全体の単一の固定長埋め込みを取得します。例えば、平均プーリングはモデルによって生成された埋め込みの平均値を計算します。

以下の図は、このプロセスを要約しています:

Sentence Transformersライブラリをpip install -U sentence-transformersでインストールすることを忘れないでください。コードでは、この2ステップのプロセスが次のように簡単になります:

from sentence_transformers import SentenceTransformer, models

## ステップ1:既存の言語モデルを使用する
word_embedding_model = models.Transformer('distilroberta-base')

## ステップ2:トークン埋め込みに対するプール関数を使用する
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

## モジュール引数を使用してステップ1とステップ2を結合する
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

上記のコードからわかるように、センテンス変換モデルはモジュール(つまり、連続して実行されるレイヤーのリスト)で構成されています。入力テキストは最初のモジュールに入力され、最終的な出力は最後のコンポーネントから得られます。見た目はシンプルですが、上記のモデルはセンテンス変換モデルの典型的なアーキテクチャです。必要に応じて、追加のレイヤー(例:密なレイヤー、単語の袋、畳み込みなど)を追加することもできます。

なぜ、BERTやRobertaなどのTransformerモデルをそのまま使用して文やテキストの埋め込みを作成しないのでしょうか?少なくとも2つの理由があります。

  1. 事前学習済みのTransformerは、意味的な検索タスクを実行するために重い計算を必要とします。例えば、10,000の文のコレクションから最も類似したペアを見つける場合、BERTでは約5000万の推論計算(約65時間)が必要です。一方、BERTセンテンス変換モデルでは、その時間を約5秒に短縮できます。
  2. トレーニング済みのTransformerは、そのままでは文の表現が劣っています。BERTモデルは、トークン埋め込みを平均して文の埋め込みを作成する場合、2014年に開発されたGloVe埋め込みよりも性能が低下します。

このセクションでは、スクラッチからセンテンス変換モデルを作成しています。既存のセンテンス変換モデルをファインチューニングしたい場合は、上記の手順をスキップし、Hugging Face Hubからモデルをインポートすることができます。センテンス変換モデルのほとんどは、「Sentence Similarity(文の類似度)」タスクにあります。ここでは、「sentence-transformers/all-MiniLM-L6-v2」モデルをロードしています:

from sentence_transformers import SentenceTransformer

model_id = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(model_id)

次に、最も重要な部分であるデータセットの形式について説明します。

センテンス変換モデルのトレーニングのためのデータセットの準備方法

センテンストランスフォーマーモデルをトレーニングするには、ある程度の類似性があるということを何らかの方法でモデルに伝える必要があります。したがって、データの各例には、2つの文が似ているか異なるかをモデルが理解するためにラベルまたは構造が必要です。

残念ながら、センテンストランスフォーマーモデルをトレーニングするためにデータを準備するための単一の方法はありません。それは主に目標とデータの構造に依存します。最もありそうなシナリオである明示的なラベルがない場合、文を取得したドキュメントの設計からそれを導出することができます。例えば、同じレポート内の2つの文は、異なるレポートの2つの文よりも比較可能であるべきです。隣接する文は、隣接しない文よりも比較可能かもしれません。

さらに、データの構造は使用できる損失関数に影響を与えます。これについては、次のセクションで説明します。

この記事のノートブックコンパニオンには、すでに実装済みのコードがすべて含まれていますので、覚えておいてください。

ほとんどのデータセットの構成は、以下の4つの形式のいずれかを取ります(以下にそれぞれの場合の例を示します):

  • Case 1:例は2つの文と、それらがどれだけ似ているかを示すラベルがあります。ラベルは整数または浮動小数点数であることができます。これは、自然言語推論(NLI)用に元々準備されたデータセットに適用されます。なぜなら、それらはお互いを推論するかどうかを示すラベルを持つ文のペアを含んでいるからです。
  • Case 2:ラベルのない正の(似ている)文のペアです。たとえば、類似文のペア、フルテキストとその要約のペア、重複した質問のペア、(queryresponse)のペア、または(source_languagetarget_language)のペアです。自然言語推論データセットも、推論文をペアリングすることでこのようにフォーマットすることができます。この形式でデータを持っていると、センテンストランスフォーマーモデルの最も使用される損失関数の1つであるMultipleNegativesRankingLossを使用することができます。
  • Case 3:整数ラベルを持つ文です。このデータ形式は、損失関数によって簡単に3つの文(三つ組)に変換されます。最初の文は「アンカー」であり、2番目の文はアンカーと同じクラスの「ポジティブ」であり、3番目の文は異なるクラスの「ネガティブ」です。各文には、それが所属するクラスを示す整数ラベルがあります。
  • Case 4:クラスまたは文のラベルがない三つ組(アンカー、ポジティブ、ネガティブ)です。

例えば、このチュートリアルでは、4番目のケースのデータセットを使用してセンテンストランスフォーマーをトレーニングします。次に、Notebook Companionを使用して、2番目のケースのデータセット構成でファインチューニングします。

センテンストランスフォーマーモデルは、人間のラベリング(ケース1とケース3)またはテキストのフォーマットから自動的に導き出されるラベル(主にケース2)でトレーニングすることができます(ケース4はラベルを必要としませんが、三つ組のデータを見つけるのはより難しいです。しかし、MegaBatchMarginLoss関数のように処理する場合は可能です)。

Hugging Face Hubには、上記の各ケースのデータセットがあります。また、Hubのデータセットには、ダウンロードする前にデータセットの構造を表示するためのデータセットプレビュー機能もあります。ここにそれぞれのケースのサンプルデータセットがあります:

  • Case 1:2つの文の間の類似度の度合いを示すラベル(たとえば{0,1,2}で0が矛盾、2が含意)を持つ場合、自然言語推論と同じセットアップを使用できます。SNLIデータセットの構造を確認してください。

  • Case 2:センテンス圧縮データセットは、正のペアで構成された例があります。データセットに2つ以上の正の文がある場合(たとえばCOCO CaptionsやFlickr30k Captionsデータセットのようなクインテットの場合)、例を異なる組み合わせの正のペアにフォーマットすることができます。

  • Case 3:TRECデータセットには、各文のクラスを示す整数ラベルがあります。Yahoo Answers Topicsデータセットの各例には、三つの文とそのトピックを示すラベルが含まれているため、各例を三つに分割することができます。

  • Case 4:Quora Tripletsデータセットには、ラベルのない三つ組(アンカー、ポジティブ、ネガティブ)があります。

次のステップは、センテンストランスフォーマーモデルが理解できる形式にデータセットを変換することです。モデルは生の文字列のリストを受け入れることはできません。各例はsentence_transformers.InputExampleクラスに変換され、それからtorch.utils.data.DataLoaderクラスに変換され、例をバッチ処理してシャッフルする必要があります。

Hugging Face Datasetsをpip install datasetsでインストールします。次にload_dataset関数を使ってデータセットをインポートします:

from datasets import load_dataset

dataset_id = "embedding-data/QQP_triplets"
dataset = load_dataset(dataset_id)

このガイドでは、ラベルのないトリプレットデータセットを使用します。上記の4番目のケースです。

datasetsライブラリを使ってデータセットを探索できます:

print(f"- {dataset_id}データセットには{dataset['train'].num_rows}の例があります。")
print(f"- 各例は値として{type(dataset['train'][0])}と{type(dataset['train'][0]['set'])}です。")
print(f"- 例は次のようになります:{dataset['train'][0]}")

出力:

- embedding-data/QQP_tripletsデータセットには101762の例があります。
- 各例はで、値はです。
- 例は次のようになります:{'set': {'query': 'Why in India do we not have one on one political debate as in USA?', 'pos': ['Why can't we have a public debate between politicians in India like the one in US?'], 'neg': ['Can people on Quora stop India Pakistan debate? We are sick and tired seeing this everyday in bulk?'...]

query(アンカー)は単一の文、pos(ポジティブ)は文のリスト(表示されるものは1つの文のみ)、neg(ネガティブ)は複数の文のリストです。

例をInputExampleに変換します。簡単のために、(1) embedding-data/QQP_tripletsデータセットのポジティブのうちの1つとネガティブのうちの1つのみを使用します。 (2) 使用可能な例の1/2のみ使用します。例の数を増やすと、より良い結果が得られます。

from sentence_transformers import InputExample

train_examples = []
train_data = dataset['train']['set']
# アジリティのために使用可能なデータの1/2のみ使用します
n_examples = dataset['train'].num_rows // 2

for i in range(n_examples):
  example = train_data[i]
  train_examples.append(InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]))

トレーニング例をDataloaderに変換します。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

次のステップは、データ形式に使用できる適切な損失関数を選択することです。

Sentence Transformersモデルのトレーニング用損失関数

データが持つ可能性のある4つの異なる形式を覚えていますか?それぞれには異なる損失関数が関連付けられています。

ケース1:2つの文と、それらがどれだけ似ているかを示すラベルがある場合。損失関数は、(1) ラベルが最も近い文がベクトル空間上で近くなるように最適化され、(2) ラベルが最も遠い文ができるだけ遠くなるように最適化されます。損失関数はラベルの形式に依存します。整数の場合はContrastiveLossまたはSoftmaxLossを使用し、浮動小数点数の場合はCosineSimilarityLossを使用できます。

ケース2:ラベルのない2つの類似した文(2つのポジティブ)のみがある場合、MultipleNegativesRankingLoss関数を使用できます。また、MegaBatchMarginLossを使用することもでき、その場合は例をトリプレット(anchor_i, positive_i, positive_j)に変換します。ここでpositive_jはネガティブとして機能します。

ケース3:サンプルが[anchor, positive, negative]の形式のトリプレットで、各々に整数のラベルがある場合、損失関数はアンカーとポジティブがベクトル空間上でより近くなるようにモデルを最適化します。アンカーとネガティブの間はできるだけ遠くなります。整数でラベル付けされたデータにはBatchHardTripletLossを使用できます。この場合、同じラベルを持つサンプルは類似していると仮定し、アンカーとポジティブには同じラベルが必要で、ネガティブには異なるラベルが必要です。また、BatchAllTripletLossBatchHardSoftMarginTripletLossBatchSemiHardTripletLossを使用することもできます。それらの違いはこのチュートリアルの範囲を超えていますが、Sentence Transformersのドキュメントで確認できます。

ケース4:トリプレット内の各文に対するラベルがない場合は、TripletLossを使用する必要があります。この損失関数は、アンカーとポジティブな文の間の距離を最小化し、アンカーとネガティブな文の間の距離を最大化します。

この図は、さまざまなタイプのデータセット形式、Hub内の例データセット、および適切な損失関数をまとめています。

概念的に適切な損失関数を選ぶことが最も困難な部分です。コードでは、次の2行しかありません:

from sentence_transformers import losses

train_loss = losses.TripletLoss(model=model)

データセットが所望の形式になり、適切な損失関数が準備されたら、Sentence Transformersの適合とトレーニングは簡単です。

センテンストランスフォーマーモデルのトレーニングまたはファインチューニングの方法

“SentenceTransformersは、独自の文/テキスト埋め込みモデルをファインチューニングすることを容易にするように設計されています。特定のタスクの埋め込みを調整するために組み合わせることができるほとんどのビルディングブロックを提供しています。” – Sentence Transformersのドキュメント。

トレーニングまたはファインチューニングは次のようになります:

model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=10)

既存のSentence Transformersモデルをファインチューニングする場合(Notebook Companionを参照)、直接fitメソッドを呼び出すことができます。これが新しいSentence Transformersモデルの場合は、”Sentence Transformersモデルの動作方法”セクションと同様に最初に定義する必要があります。

これで、新しいまたは改善されたSentence Transformersモデルができました!Hugging Face Hubに共有しますか?

まず、Hugging Face Hubにログインします。アカウント設定でwriteトークンを作成する必要があります。ログインするための2つのオプションがあります:

  1. ターミナルでhuggingface-cli loginと入力し、トークンを入力します。

  2. Pythonのノートブックの場合は、notebook_loginを使用できます。

from huggingface_hub import notebook_login

notebook_login()

次に、トレーニング済みモデルからsave_to_hubメソッドを呼び出してモデルを共有できます。デフォルトでは、モデルはアカウントにアップロードされますが、organizationパラメータを渡すことで組織にアップロードすることもできます。 save_to_hubは自動的にモデルカード、推論ウィジェット、コードスニペットの例などを生成します。 train_datasets引数を使用して、モデルのトレーニングに使用したデータセットのリストをHubのモデルカードに自動的に追加できます:

model.save_to_hub(
    "distilroberta-base-sentence-transformer", 
    organization= # ユーザー名を追加してください
    train_datasets=["embedding-data/QQP_triplets"],
    )

私はNotebook Companionでembedding-data/sentence-compressionデータセットとMultipleNegativesRankingLoss損失を使用して、この同じモデルをファインチューニングしました。

Sentence Transformersの制限事項は何ですか?

Sentence Transformersモデルは、意味的な検索において単純なTransformersモデルよりもはるかに優れた性能を発揮します。しかし、Sentence Transformersモデルがうまく機能しない場合はどこですか?タスクが分類である場合、文の埋め込みを使用することは間違ったアプローチです。その場合、🤗 Transformersライブラリがより適した選択肢となります。

追加リソース

  • 埋め込みを使った始め方。
  • 意味的な検索の理解。
  • 最初のSentence Transformersモデルを始めましょう。
  • Sentence Transformersを使用してプレイリストを生成する。
  • Hugging Face + Sentence Transformersのドキュメント。

読んでいただきありがとうございました!埋め込み作成をお楽しみください。

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

AIテクノロジー

アンソニー・グーネティレケ氏は、Amdocsのグループ社長であり、テクノロジー部門および戦略部門の責任者です- インタビューシリーズ

アンソニー・グーネティレーケは、Amdocsでグループ社長、テクノロジーと戦略担当です彼と企業戦略チームは、会社の戦略を策...

人工知能

『ジュリエット・パウエル&アート・クライナー、The AI Dilemma – インタビューシリーズの著者』

『AIのジレンマ』は、ジュリエット・パウエルとアート・クライナーによって書かれましたジュリエット・パウエルは、著者であ...

人工知能

「Kognitosの創設者兼CEO、ビニー・ギル- インタビューシリーズ」

ビニー・ギルは、複数の役職と企業を横断する多様で幅広い業務経験を持っていますビニーは現在、Kognitosの創設者兼CEOであり...

データサイエンス

アステラソフトウェアのCOO、ジェイ・ミシュラ - インタビューシリーズ

ジェイ・ミシュラは、急速に成長しているエンタープライズ向けデータソリューションの提供企業であるAstera Softwareの最高執...

人工知能

「LeanTaaSの創設者兼CEO、モハン・ギリダラダスによるインタビューシリーズ」

モーハン・ギリダラダスは、AIを活用したSaaSベースのキャパシティ管理、スタッフ配置、患者フローのソフトウェアを提供する...

人工知能

「マーシャンの共同創設者であるイータン・ギンスバーグについてのインタビューシリーズ」

エタン・ギンズバーグは、マーシャンの共同創業者であり、すべてのプロンプトを最適なLLMに動的にルーティングするプラットフ...