「LoRAを使用してAmazon SageMakerでWhisperモデルを微調整する」

「LoRAを使ってAmazon SageMakerでWhisperモデルを微調整する方法」

Whisperは、ウェブからの680,000時間の監視されたデータを使用してトレーニングされた自動音声認識(ASR)モデルであり、さまざまな言語とタスクを対象としています。その制約の一つは、マラーティー語やドラヴィダ語などの低リソース言語での低性能ですが、これは微調整で対処できます。ただし、Whisperモデルの微調整は、計算リソースとストレージの要件の両方の面で非常に困難となっています。Whisperモデルの完全な微調整を5〜10回行うと、約100時間のA100 GPU(40GB SXM4)が必要となります(モデルのサイズとパラメータによって異なります)、また、各微調整チェックポイントには約7GBのストレージスペースが必要です。この高い計算およびストレージの要求は、リソースが限られている環境では大きな障害となる可能性があり、意義のある結果を実現することが非常に困難になります。

低ランク適応、またはLoRAとしても知られている手法は、モデルの微調整には独自のアプローチを取ります。この手法では、事前トレーニング済みのモデルの重みを静的な状態に保ち、各Transformer構造のレイヤーに学習可能なランク分解行列を導入します。この手法により、下流タスクに必要な学習可能なパラメータの数を10,000分の1に減らし、GPUメモリの要件を3倍に削減することが可能です。モデルの品質については、LoRAは伝統的な微調整手法と同等またはそれ以上の性能を示しており、学習可能なパラメータが少ない状態で動作します(オリジナルのLoRA論文の結果を参照)。さらに、トレーニングのスループットが向上するという利点もあります。アダプタアダプタの手法とは異なり、LoRAは推論時に追加のレイテンシーを導入せず、展開フェーズでのモデルの効率を維持します。LoRAを使用してWhisperを微調整することは、有望な結果を示しています。例えば、Whisper-Large-v2を考えてみましょう:8GBメモリGPU上の12時間の一般音声データセットで3エポック実行すると、6〜8時間かかります。これは、比較可能な性能で完全な微調整に比べて5倍高速です。

Amazon SageMakerは、WhisperのLoRA微調整を実装するための理想的なプラットフォームです。Amazon SageMakerは、完全に管理されたインフラストラクチャ、ツール、ワークフローを使用して、任意のユースケースに対して機械学習モデルを構築、トレーニング、展開することができます。追加のモデルトレーニングの利点には、マネージドスポットトレーニングによる低コストのトレーニング、AWS GPUインスタンス間でモデルとトレーニングデータセットを分割するための分散トレーニングライブラリなどがあります。もっと詳しく。トレーニングしたSageMakerモデルは、SageMaker上で直接推論にデプロイすることができます。この記事では、SageMakerでLoRA微調整を実装するためのステップバイステップのガイドを紹介します。この実装に関連するソースコードはGitHubで入手できます。

微調整用のデータセットの準備

微調整のタスクには、低リソース言語であるマラーティー語を使用します。Hugging Faceのデータセットライブラリを使用すると、Common Voiceデータセットをトレーニング用とテスト用のデータセットにダウンロードして分割することができます。次のコードを参照してください。

from datasets import load_dataset, DatasetDict
language = "Marathi"
language_abbr = "mr"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0"
common_voice = DatasetDict()
common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test", use_auth_token=True)

Whisper音声認識モデルでは、オーディオ入力を16kHzのモノラル16ビット符号付き整数WAVファイルとする必要があります。Common Voiceデータセットは48Kサンプリングレートなので、オーディオファイルをダウンサンプリングする必要があります。そして、Whisperの特徴抽出器をオーディオに適用して対数メルスペクトログラム特徴量を抽出し、Whisperのトークナイザを用いてトランスクリプトの各文をトークンIDに変換する必要があります。次のコードを参照してください。

from transformers import WhisperFeatureExtractorfrom transformers import WhisperTokenizer
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)
def prepare_dataset(batch):
  # 48kHzから16kHzにオーディオデータを読み込みリサンプリングする
  audio = batch["audio"]
  batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
  batch["labels"] = tokenizer(batch["sentence"]).input_ids
  return batch
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2)
common_voice.save_to_disk("marathi-common-voice-processed")
!aws s3 cp --recursive "marathi-common-voice-processed" s3://<Your-S3-Bucket>

トレーニングサンプルをすべて処理した後、処理されたデータをAmazon S3にアップロードします。これにより、ファインチューニングステージで処理されたトレーニングデータを使用する際に、ローカルディスクにコピーする代わりにFastFileを使用してS3ファイルを直接マウントできます。

from sagemaker.inputs import TrainingInput
training_input_path=s3uritraining = TrainingInput(s3_data_type='S3Prefix', s3_data=training_input_path, distribution='FullyReplicated', input_mode='FastFile')

モデルのトレーニング

デモンストレーションでは、事前学習済みモデルとしてwhisper-large-v2を使用します(現在はwhisper v3も利用可能です)。これは、Hugging Faceのtransformersライブラリを介してインポートできます。トレーニング効率をさらに向上させるために、8ビットの量子化を使用できます。8ビットの量子化は、浮動小数点から8ビット整数に丸めることによるメモリ最適化を提供します。これは一般的に使用されるモデル圧縮技術であり、推論時の精度をあまり犠牲にせずにメモリ使用量を削減します。

8ビット量子化形式の事前学習済みモデルをロードするには、モデルのインスタンス化時にload_in_8bit=True引数を追加するだけです。次のコードのように、モデルの重みが8ビットに量子化され、メモリ使用量が削減されます。

from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")

このノートでは、Hugging FaceのpeftパッケージからLoRAの実装を使用します。LoRAを使用してモデルを微調整するには、次の4つのステップがあります。

  1. ベースモデルをインスタンス化します(前のステップと同様)。
  2. LoraConfigを作成し、LoRA固有のパラメータが定義されます。
  3. get_peft_model()でベースモデルをラップし、トレーニング可能なPeftModelを取得します。
  4. ベースモデルとしてPeftModelをトレーニングします。

以下のコードを参照してください。

from peft import LoraConfig, get_peft_model
config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
model = get_peft_model(model, config)
training_args = Seq2SeqTrainingArguments(output_dir=args.model_dir, per_device_train_batch_size=int(args.train_batch_size), gradient_accumulation_steps=1, learning_rate=float(args.learning_rate), warmup_steps=args.warmup_steps, num_train_epochs=args.num_train_epochs, evaluation_strategy="epoch", fp16=True, per_device_eval_batch_size=args.eval_batch_size, generation_max_length=128, logging_steps=25, remove_unused_columns=False, label_names=["labels"])
trainer = Seq2SeqTrainer(args=training_args, model=model, train_dataset=train_dataset["train"], eval_dataset=train_dataset.get("test", train_dataset["test"]), data_collator=data_collator, tokenizer=processor.feature_extractor)

SageMakerトレーニングジョブを実行するために、独自のDockerコンテナを使用します。Dockerイメージは、GitHubからダウンロードできます。この中には、ffmpeg4とgit-lfsが他のPythonの要件とともにパッケージ化されています。独自のDockerコンテナをSageMakerで動作するように適応する方法の詳細については、独自のトレーニングコンテナの適応を参照してください。そして、Hugging Face Estimatorを使用してSageMakerトレーニングジョブを開始できます。

OUTPUT_PATH= f's3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/'huggingface_estimator = HuggingFace(entry_point='train.sh',source_dir='./src',output_path= OUTPUT_PATH,instance_type=instance_type,instance_count=1,# transformers_version='4.17.0',# pytorch_version='1.10.2',py_version='py310',image_uri=<ECR-PATH>,role=ROLE,metric_definitions = metric_definitions,volume_size=200,distribution=distribution,keep_alive_period_in_seconds=1800,environment=environment,)huggingface_estimator.fit(job_name=TRAINING_JOB_NAME, wait=False)

LoRAの実装により、Whisperの大規模な微調整タスクを単一のGPUインスタンス(例:ml.g5.2xlarge)で実行することができました。対照的に、Whisperの大規模な完全な微調整タスクには複数のGPU(例:ml.p4d.24xlarge)とより長い学習時間が必要です。具体的には、私たちの実験では、完全な微調整タスクがLoRAのアプローチと比較して24倍のGPU時間を要することが示されました。

モデルのパフォーマンスを評価する

微調整されたWhisperモデルのパフォーマンスを評価するために、ホールドアウトテストセットで単語エラーレート(WER)を計算します。WERは、予測されたトランスクリプトと正解のトランスクリプトの差を測定します。WERが低いほど性能が良いことを示します。次のスクリプトを実行し、事前学習済みモデルと微調整モデルのWERの差を比較できます:

metric = evaluate.load("wer")eval_dataloader = DataLoader(common_voice["test"], batch_size=8, collate_fn=data_collator)model.eval()for step, batch in enumerate(tqdm(eval_dataloader)):with torch.cuda.amp.autocast():with torch.no_grad():generated_tokens = (model.generate(input_features=batch["input_features"].to("cuda"),decoder_input_ids=batch["labels"][:, :4].to("cuda"),max_new_tokens=255,).cpu().numpy())labels = batch["labels"].cpu().numpy()labels = np.where(labels != -100, labels, tokenizer.pad_token_id)decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)metric.add_batch(predictions=decoded_preds,references=decoded_labels,)del generated_tokens, labels, batchgc.collect()wer = 100 * metric.compute()print(f"{wer=}")

結論

本記事では、最新の音声認識モデルであるWhisperの微調整を実証しました。特に、Hugging FaceのPEFT LoRAを使用し、効率的なトレーニングのために8ビットの量子化を可能にしました。また、SageMakerでトレーニングジョブを実行する方法も示しました。

これは重要な最初のステップですが、Whisperモデルをさらに改善するためには、いくつかの方法があります。将来的には、SageMakerの分散トレーニングを使用して、より大規模なデータセットでトレーニングをスケーリングすることを検討してください。これにより、モデルはより多様で包括的なデータでトレーニングすることができ、精度が向上します。また、Whisperモデルのサービング時のレイテンシを最適化し、リアルタイム音声認識を実現することもできます。さらに、モデルアーキテクチャとトレーニングスキームの変更が必要となる長い音声トランスクリプションの処理にも取り組むことができます。

謝辞

本記事の洞察に対するParas Mehra、John Sol、Evandro Francoの有益なフィードバックとレビューに感謝します。

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

機械学習

ショッピファイの製品推奨アプリに生成AIを導入する

ショッピファイの製品推薦アプリケーションであるSearch and DiscoveryにジェネレーティブAIがどのように実装されたかについ...

データサイエンス

「ゼロからLLMを構築する方法」

「これは、大規模言語モデル(LLM)を実践的に使用するシリーズの6番目の記事です以前の記事では、プロンプトエンジニアリン...

AI研究

アップルの研究者がDeepPCRを公開:通常は順次処理される操作を並列化してニューラルネットワークの推論とトレーニングの速度を向上させる新しい機械学習アルゴリズム

人工知能や深層学習の進展により、さまざまな革新が実現されています。テキストや画像の合成、分割、分類などの複雑なタスク...

データサイエンス

「VAST DataのプラットフォームがAIイノベーションの障壁を取り除く方法」

データが存在する場所に関係なく、より多くのデータへの高速アクセスは、AIに基づくアプリケーション、ソリューション、およ...

データサイエンス

AIと機械学習のためのReactJS:強力な組み合わせ

このブログ記事では、ReactJSとAI/MLが組み合わされることで、パワフルでインタラクティブなウェブアプリケーションを構築す...

機械学習

人間とAIの協力

「AIと人間の知能の関係を探求する中で、最近のGenAIの出現は、その人間の知能を超越する能力について疑問を投げかけています」