マルチリンガルASRのためのWhisperの調整を行います with 🤗 Transformers
'WhisperのマルチリンガルASRの調整を🤗 Transformersで行います'
このブログでは、ハギングフェイス🤗トランスフォーマーを使用して、Whisperを任意の多言語ASRデータセットに対して細かく調整する手順を段階的に説明します。このブログでは、Whisperモデル、Common Voiceデータセット、および細かな調整の背後にある理論について詳しく説明し、データの準備と細かい調整の手順を実行するためのコードセルと共に提供しています。説明は少ないですが、すべてのコードがあるより簡略化されたバージョンのノートブックは、関連するGoogle Colabを参照してください。
目次
- はじめに
- Google ColabでのWhisperの細かい調整
- 環境の準備
- データセットの読み込み
- 特徴抽出器、トークナイザー、およびデータの準備
- トレーニングと評価
- デモの作成
- 締めくくり
はじめに
Whisperは、OpenAIのAlec Radfordらによって2022年9月に発表された自動音声認識(ASR)のための事前学習モデルです。Whisperは、Wav2Vec 2.0などの先行研究とは異なり、ラベル付きの音声トランスクリプションデータで事前学習されています。具体的には、680,000時間のデータが使用されています。これは、Wav2Vec 2.0の訓練に使用されるラベルなしの音声データ(60,000時間)よりも桁違いに多いデータです。さらに、この事前学習データのうち117,000時間が多言語ASRデータです。これにより、96以上の言語に適用できるチェックポイントが生成され、その多くは低リソース言語とされています。
このような大量のラベル付きデータにより、Whisperは事前学習データから音声認識の教師ありタスクを直接学習し、音声トランスクリプションデータからテキストへのマッピングを学習します。そのため、Whisperはパフォーマンスの高いASRモデルを得るためにほとんど追加の細かい調整を必要としません。これに対して、Wav2Vec 2.0は非教師付きタスクのマスク予測で事前学習されており、音声から隠れた状態への中間的なマッピングを学習します。非教師付きの事前学習は音声の高品質な表現を生み出しますが、音声からテキストへのマッピングは学習されません。このマッピングは細かい調整中にのみ学習されるため、競争力のあるパフォーマンスを得るにはより多くの細かい調整が必要です。
680,000時間のラベル付き事前学習データにスケールされると、Whisperモデルは多くのデータセットとドメインに対して高い汎化能力を示します。事前学習されたチェックポイントは、LibriSpeech ASRのtest-cleanサブセットで約3%の単語エラーレート(WER)を達成し、TED-LIUMでは4.7%のWERで新たな最先端の結果を実現します(Whisper論文の表8を参照)。Whisperが事前学習中に獲得した多言語ASRの知識は、他の低リソース言語に活用することができます。細かい調整により、事前学習済みのチェックポイントを特定のデータセットと言語に適応させることで、これらの結果をさらに改善することができます。
Whisperは、Transformerベースのエンコーダーデコーダーモデルであり、シーケンスからシーケンスへのモデルとも呼ばれています。Whisperは、オーディオのスペクトログラム特徴のシーケンスをテキストトークンのシーケンスにマッピングします。まず、生のオーディオ入力は特徴抽出器によってログメルスペクトログラムに変換されます。次に、Transformerエンコーダーはスペクトログラムをエンコードしてエンコーダーの隠れ状態のシーケンスを形成します。最後に、デコーダーはエンコーダーの隠れ状態と以前に予測されたトークンの両方に依存して、テキストトークンを自己回帰的に予測します。図1はWhisperモデルを要約しています。
シーケンス・トゥ・シーケンスモデルでは、エンコーダは音声入力を隠れた状態表現のセットに変換し、話された音声から重要な特徴を抽出します。デコーダは言語モデルの役割を果たし、隠れた状態表現を処理し、対応するテキストの転写を生成します。システムアーキテクチャ内に言語モデルを内部に組み込むことは、ディープフュージョンと呼ばれます。これは、CTC + n n n -gramなどのエンコーダと外部で言語モデルを組み合わせる浅いフュージョンとは対照的です(内部言語モデルの推定を参照)。ディープフュージョンでは、同じトレーニングデータと損失関数でエンドツーエンドでシステム全体をトレーニングすることができるため、より柔軟性があり、一般的に優れたパフォーマンスが得られます(ESBベンチマークを参照)。
Whisperは、クロスエントロピーの目的関数を使用して事前トレーニングおよび微調整され、分類タスクにおけるシーケンス・トゥ・シーケンスシステムのトレーニングに標準的な目的関数を使用しています。ここでは、システムが事前定義されたテキストトークンの認識を正しく分類するようにトレーニングされます。
Whisperのチェックポイントは、モデルサイズの異なる5つの設定で提供されています。最も小さい4つは英語のみまたは多言語のデータでトレーニングされています。最大のチェックポイントは多言語のみです。事前トレーニングされた9つのチェックポイントは、Hugging Face Hubで利用できます。チェックポイントは、以下の表にまとめられ、ハブ上のモデルへのリンクが提供されています。
デモンストレーションの目的で、244Mパラメータ(約1GB)のsmall
チェックポイントの多言語バージョンを微調整します。データについては、Common Voiceデータセットから取得したリソースの少ない言語でシステムをトレーニングおよび評価します。わずか8時間の微調整データでも、この言語で強力なパフォーマンスを実現できることを示します。
1 {}^1 1 「Whisper」という名前は、「WSPSR」という頭字語から派生しています。これは、「Web-scale Supervised Pre-training for Speech Recognition」を意味します。
Google ColabでWhisperを微調整する
環境の準備
Whisperモデルを微調整するために、いくつかの人気のあるPythonパッケージを使用します。トレーニングデータをダウンロードして準備するためにdatasets
を使用し、Whisperモデルをロードしてトレーニングするためにtransformers
を使用します。オーディオファイルの前処理にsoundfile
パッケージが必要であり、モデルのパフォーマンスを評価するためにevaluate
およびjiwer
が必要です。最後に、微調整されたモデルの派手なデモを作成するためにgradio
を使用します。
!pip install datasets>=2.6.1
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install evaluate>=0.30
!pip install jiwer
!pip install gradio
トレーニング中にモデルのチェックポイントをHugging Face Hubに直接アップロードすることを強くお勧めします。Hubでは以下の機能が提供されます:
- 統合バージョン管理:トレーニング中にモデルのチェックポイントが失われないことが保証されます。
- Tensorboardログ:トレーニングの過程で重要なメトリクスを追跡できます。
- モデルカード:モデルの機能と意図される使用方法について文書化します。
- コミュニティ:コミュニティとの共有や協力のための簡単な方法です!
ノートブックをHubにリンクするには、プロンプトにHubの認証トークンを入力するだけです。Hubの認証トークンはこちらで確認できます:
from huggingface_hub import notebook_login
notebook_login()
出力結果:
ログインに成功しました
トークンは/root/.huggingface/tokenに保存されました
データセットの読み込み
Common Voiceは、スピーカーがさまざまな言語でWikipediaのテキストを録音したクラウドソーシングデータセットのシリーズです。Common Voiceデータセットの最新版(バージョン11)を使用します。言語については、インド・アーリア語派の言語であるヒンディー語でモデルを微調整します。Common Voice 11.0には、おおよそ12時間のラベル付きヒンディー語データが含まれており、そのうち4時間はテストデータとして保持されています。
Common Voiceのデータセットページに移動して、Common Voiceの詳細を表示しましょう:mozilla-foundation/common_voice_11_0。
このページを初めて表示すると、利用規約の承認を求められます。承認すると、データセットに完全にアクセスできるようになります。
データセットの使用許可を提供した後、データセットのプレビューが表示されます。データセットのプレビューでは、データセットの最初の100サンプルが表示されます。さらに、リアルタイムで聞くことができる音声サンプルも読み込まれています。ドロップダウンメニューを使用して、Hindiのサブセットを選択することで、Common VoiceのHindiデータセットを選択することができます(Hindiの言語識別コードであるhi
を使用します):
最初のサンプルの再生ボタンを押すと、音声を聞くことができ、それに対応するテキストも表示されます。トレーニングセットとテストセットのサンプルをスクロールして、扱っている音声データとテキストデータの感触をより良く掴むことができます。イントネーションやスタイルから、これらの録音がナレーションのスピーチから取られていることがわかります。また、クラウドソーシングされたデータの共通の特徴である話者と録音品質の大きなバリエーションもおそらく気付くでしょう。
🤗 Datasetsを使用すると、データのダウンロードと準備が非常に簡単になります。たった1行のコードでCommon Voiceの分割データをダウンロードして準備することができます。Hindiは非常にリソースが少ないため、train
とvalidation
の分割データを組み合わせて、約8時間のトレーニングデータを作成します。4時間のtest
データをホールドアウトテストセットとして使用します:
from datasets import load_dataset, DatasetDict
common_voice = DatasetDict()
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test", use_auth_token=True)
print(common_voice)
出力結果:
DatasetDict({
train: Dataset({
features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 6540
})
test: Dataset({
features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 2894
})
})
ほとんどのASRデータセットは、入力オーディオサンプル(audio
)と対応する転写テキスト(sentence
)のみを提供します。Common Voiceには、ASRには関係のないaccent
やlocale
などの追加のメタデータ情報も含まれています。このノートブックをできるだけ汎用的に保つため、ファインチューニングには入力オーディオと転写テキストのみを考慮し、追加のメタデータ情報は無視します:
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
Common Voiceは多言語ASRデータセットの一例に過ぎず、私たちがダウンロードできるデータはまだたくさんあります!音声認識用の利用可能なデータセットの範囲を表示するには、以下のリンクを参照してください:ハブ上のASRデータセット。
特徴抽出器、トークナイザー、データの準備
ASRパイプラインは、次の3つのコンポーネントに分解できます:
- 生のオーディオ入力を前処理する特徴抽出器
- シーケンスからシーケンスへのマッピングを行うモデル
- モデルの出力をテキスト形式に後処理するトークナイザー
🤗 Transformersでは、WhisperモデルにはWhisperFeatureExtractorとWhisperTokenizerという関連する特徴抽出器とトークナイザーがあります。
特徴抽出器とトークナイザーの詳細については、順番に説明していきます!
WhisperFeatureExtractorの読み込み
音声は、時間とともに変化する1次元配列で表されます。ある時点の配列の値は、その点の信号の振幅です。振幅情報だけから、オーディオの周波数スペクトルを再構成し、すべての音響特徴を回復することができます。
音声は連続しているため、無限の振幅値を含んでいます。これは、有限の配列を期待するコンピュータデバイスにとって問題を引き起こします。したがって、信号を固定された時間ステップでサンプリングすることにより、音声信号を離散化します。オーディオをサンプリングする間隔は、サンプリングレートと呼ばれ、通常はサンプル/秒またはヘルツ(Hz)で測定されます。サンプリングレートを高くすると、連続的な音声信号のより良い近似が得られますが、秒あたりの値をより多く保存する必要があります。
オーディオ入力のサンプリングレートをモデルが期待するサンプリングレートに一致させることは重要です。なぜなら、異なるサンプリングレートのオーディオ信号は非常に異なる分布を持っているからです。オーディオサンプルは常に正しいサンプリングレートで処理されるべきです。そうしないと、予期しない結果が生じることがあります!例えば、サンプリングレートが16kHzのオーディオサンプルを、サンプリングレートが8kHzの環境で再生すると、オーディオは倍速で再生されるように聞こえます。同様に、間違ったサンプリングレートでオーディオを渡すと、期待するサンプリングレートを持つASRモデルが失敗することがあります。Whisper特徴抽出器は、サンプリングレートが16kHzのオーディオ入力を期待しているため、入力をこの値に一致させる必要があります。スローモーションの音声でASRシステムを誤ってトレーニングすることは避けたいです!
Whisper特徴抽出器は2つの操作を行います。まず、音声サンプルのバッチを30秒の入力長さにパッド/切り詰めます。30秒より短いサンプルは、シーケンスの末尾にゼロを追加して30秒にパッドされます(オーディオ信号のゼロは信号なしまたは無音を表します)。30秒より長いサンプルは30秒に切り詰められます。バッチ内のすべての要素が入力空間で最大長さにパッド/切り詰められているため、Whisperモデルにオーディオ入力を転送する際にはアテンションマスクは必要ありません。Whisperはこの点でユニークであり、ほとんどのオーディオモデルでは、シーケンスがパッディングされた場所やセルフアテンションメカニズムで無視されるべき場所などを示すアテンションマスクを提供することが期待されます。Whisperはアテンションマスクなしで動作し、音声信号から直接入力の無視する場所を推測するようにトレーニングされています。
Whisper特徴抽出器が行う2番目の操作は、パッドされたオーディオ配列を対数メルスペクトログラムに変換することです。これらのスペクトログラムは信号の周波数の視覚的な表現であり、フーリエ変換のようなものです。図2には例としてスペクトログラムが表示されています。y軸にはMelチャンネルがあり、特定の周波数ビンに対応しています。x軸は時間です。各ピクセルの色は、その周波数ビンの対数強度を示しています。対数メルスペクトログラムがWhisperモデルが期待する入力形式です。
Melチャンネル(周波数ビン)は音声処理において標準的であり、人間の聴覚範囲を近似するように選ばれています。Whisperの微調整に必要なのは、スペクトログラムが音声信号の周波数の視覚的な表現であるということだけです。Melチャンネルの詳細については、Mel周波数ケプストラムを参照してください。
幸いにも、🤗 Transformers Whisper特徴抽出器は、たった一行のコードでパディングとスペクトログラム変換を実行します!事前学習済みのチェックポイントから特徴抽出器をロードし、オーディオデータに備えておきましょう:
from transformers import WhisperFeatureExtractor
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
WhisperTokenizerのロード
次に、Whisperトークナイザのロード方法を見てみましょう。Whisperモデルはテキストトークンを出力し、予測されたテキストのインデックスを語彙アイテムの辞書内で示します。トークナイザはテキストトークンのシーケンスを実際のテキスト文字列にマップします(例:[1169、3797、3332] -> “the cat sat”)。
通常、ASRのためにエンコーダのみのモデルを使用する場合、我々はConnectionist Temporal Classification(CTC)を使用してデコードします。ここでは、使用するデータセットごとにCTCトークナイザをトレーニングする必要があります。エンコーダ-デコーダのアーキテクチャを使用する利点の1つは、事前学習済みモデルからトークナイザを直接利用できることです。
Whisperトークナイザは、96の事前学習言語の転写で事前学習されています。そのため、ほとんどの多言語ASRアプリケーションに適した包括的なバイトペアを持っています。ヒンディー語の場合、トークナイザをロードし、さらなる変更なしで微調整に使用することができます。単にターゲット言語とタスクを指定するだけで、これらの引数はエンコードされたラベルシーケンスの先頭に言語とタスクのトークンを接頭辞として追加するようトークナイザに指示します:
from transformers import WhisperTokenizer
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
トークナイザーがヒンディー語の文字を正しくエンコードすることを確認するために、Common Voiceデータセットの最初のサンプルをエンコードしてデコードできます。トークナイザーは、トランスクリプトの開始と終了を示す特別なトークン、言語トークン、およびタスクトークン(前のステップで指定した引数によって指定されます)をシーケンスの先頭と末尾に追加します。ラベルIDをデコードする際には、これらの特別なトークンをスキップするオプションがあり、元の入力形式の文字列を返すことができます:
input_str = common_voice["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)
print(f"Input: {input_str}")
print(f"Decoded w/ special: {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal: {input_str == decoded_str}")
出力結果を表示:
Input: खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
Decoded w/ special: <|startoftranscript|><|hi|><|transcribe|><|notimestamps|>खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई<|endoftext|>
Decoded w/out special: खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
Are equal: True
WhisperProcessorを作成する
フィーチャー抽出器とトークナイザーを簡単に使用するために、両方を1つのWhisperProcessor
クラスにまとめることができます。このプロセッサオブジェクトは、WhisperFeatureExtractor
とWhisperProcessor
を継承しており、必要に応じてオーディオ入力とモデルの予測に使用することができます。これにより、トレーニング中に追跡する必要があるオブジェクトは2つだけになります:processor
とmodel
です:
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
データの準備
データの形式を確認するために、Common Voiceデータセットの最初の例を表示してみましょう:
print(common_voice["train"][0])
出力結果を表示:
{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3',
'array': array([0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 9.6724887e-07,
1.5334779e-06, 1.0415988e-06], dtype=float32),
'sampling_rate': 48000},
'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}
1次元の入力オーディオ配列と対応するターゲットの転写があることがわかります。サンプリングレートの重要性について詳しく説明しましたが、オーディオのサンプリングレートをWhisperモデル(16kHz)に合わせる必要があります。入力オーディオは48kHzでサンプリングされているため、Whisperフィーチャー抽出器に渡す前に16kHzにダウンサンプリングする必要があります。
データセットの cast_column
メソッドを使用して、オーディオ入力を正しいサンプリングレートに設定します。この操作はオーディオを直接変更するのではなく、オーディオサンプルが最初に読み込まれる際に datasets
にオンザフライでリサンプルするためのシグナルを送ります。
from datasets import Audio
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
Common Voice データセットの最初のオーディオサンプルを再ロードすると、目的のサンプリングレートにリサンプルされます。
print(common_voice["train"][0])
出力結果:
{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3',
'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-3.4206650e-07, 3.2979898e-07, 1.0042874e-06], dtype=float32),
'sampling_rate': 16000},
'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}
素晴らしいです!サンプリングレートが16kHzにダウンサンプルされたことがわかります。配列の値も異なっており、以前はおおよそ3つの振幅値が1つになっています。
それでは、モデルのためにデータを準備するための関数を作成しましょう:
batch["audio"]
を呼び出してオーディオデータをロードおよびリサンプルします。上述したように、🤗 Datasets は必要なリサンプリング操作をオンザフライで実行します。- フィーチャーエクストラクタを使用して、1次元のオーディオ配列からログメルスペクトログラムの入力特徴量を計算します。
- トランスクリプションをラベルIDにエンコードします。
def prepare_dataset(batch):
# 48kHzから16kHzへのオーディオデータのロードおよびリサンプリング
audio = batch["audio"]
# 入力オーディオ配列からログメル入力特徴量を計算
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
# 対象テキストをラベルIDにエンコード
batch["labels"] = tokenizer(batch["sentence"]).input_ids
return batch
データ準備関数をデータセットの全ての訓練例に適用するために、.map
メソッドを使用できます:
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4)
よし!これで訓練のためにデータが完全に準備されました!次に、このデータを使用して Whisper をファインチューニングする方法を見ていきましょう。
注意:現在、datasets
はオーディオの読み込みとリサンプリングに torchaudio
と librosa
の両方を使用しています。独自のカスタマイズされたデータの読み込み/サンプリングを実装したい場合は、"path"
列を使用してオーディオファイルのパスを取得し、"audio"
列を無視することができます。
訓練と評価
データの準備ができたので、訓練パイプラインに入る準備ができました。🤗 Trainer が大部分の作業を行ってくれます。私たちが行う必要があるのは次のことです:
-
データコレータの定義:データコレータは、前処理されたデータを受け取り、モデルに適した PyTorch テンソルに準備します。
-
評価メトリクス:評価時には、単語エラーレート(WER)メトリクスを使用してモデルを評価したいです。この計算を処理する
compute_metrics
関数を定義する必要があります。 -
事前学習済みのチェックポイントの読み込み:事前学習済みのチェックポイントを読み込み、訓練のために正しく設定する必要があります。
-
訓練引数の定義:これらは訓練スケジュールを構築するために 🤗 Trainer によって使用されます。
モデルの調整が完了したら、テストデータで評価して、ヒンディー語の音声を正しく転写できるようにトレーニングされていることを確認します。
データコレータを定義する
シーケンス対シーケンス音声モデルのデータコレータは、input_features
とlabels
を独立して扱う点でユニークです。つまり、input_features
は特徴抽出器によって処理され、labels
はトークナイザによって処理されます。
input_features
はすでに30秒にパディングされ、固定次元のログメルスペクトログラムに変換されているため、バッチ化されたPyTorchテンソルに変換するだけで済みます。これは、特徴抽出器の.pad
メソッドを使用してreturn_tensors=pt
として行います。ここでは追加のパディングは適用されません。入力は固定次元なので、input_features
は単純にPyTorchテンソルに変換されます。
一方、labels
はパディングされていません。まず、トークナイザの.pad
メソッドを使用してバッチ内の最大長にシーケンスをパディングします。次に、パディングトークンは-100
で置き換えられます。これにより、これらのトークンが損失の計算時に考慮されないようになります。そして、トレーニング中にラベルシーケンスの先頭にトランスクリプトトークンを追加するため、ラベルシーケンスの先頭からトランスクリプトトークンを切り取ります。
以前定義したWhisperProcessor
を使用して、特徴抽出器とトークナイザの操作を同時に実行できます:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# 入力とラベルを別々の長さで処理するために分割する
# 入力の音声は、単純にtorchテンソルを返すことで処理します
input_features = [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# トークン化されたラベルシーケンスを取得します
label_features = [{"input_ids": feature["labels"]} for feature in features]
# ラベルを最大長にパディングします
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# パディングを-100で置き換えて、損失を正しく無視します
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
# 前のトークン化ステップでbosトークンが追加されている場合は、後で追加されるのでここで切り取ります
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
さきほど定義したデータコレータを初期化しましょう:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
評価メトリクス
次に、評価セットで使用する評価メトリクスを定義します。ここでは、ASRシステムの評価において「事実上の」メトリクスであるWord Error Rate (WER) メトリクスを使用します。詳細については、WERドキュメントを参照してください。WERメトリクスは🤗 Evaluateから読み込まれます:
import evaluate
metric = evaluate.load("wer")
次に、モデルの予測を受け取り、WERメトリクスを返す関数compute_metrics
を定義するだけです。この関数では、-100
をpad_token_id
で置き換えます(損失の計算時にパディングトークンを正しく無視するためにデータコレータで適用した手順を元に戻します)。それから、予測されたIDとラベルIDを文字列にデコードします。最後に、予測と参照ラベル間のWERを計算します:
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
# -100をpad_token_idで置き換える
label_ids[label_ids == -100] = tokenizer.pad_token_id
# メトリクスの計算時にトークンをグループ化したくないので、スペシャルトークンをスキップして予測とラベルをデコードする
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
事前学習済みのチェックポイントの読み込み
では、事前学習済みのWhisper small
チェックポイントを読み込みましょう。再度、これは🤗 Transformersを使用することで簡単に行うことができます!
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
Whisperモデルは、自己回帰生成が開始される前にモデル出力として強制されるトークンID(forced_decoder_ids
)を持っています。これらのトークンIDは、ゼロショットASRのためのトランスクリプション言語とタスクを制御します。ファインチューニングでは、これらのIDをNone
に設定します。なぜなら、モデルを正しい言語(ヒンディー語)とタスク(トランスクリプション)を予測するように訓練するためです。また、生成中に完全に抑制されるトークン(suppress_tokens
)もあります。これらのトークンの対数確率は-inf
に設定されているため、サンプリングされることはありません。これらのトークンを空のリストにオーバーライドして、トークンを抑制しないようにします:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
トレーニング引数の定義
最後のステップでは、トレーニングに関連するすべてのパラメータを定義します。以下にいくつかのパラメータを説明します:
output_dir
:モデルの重みを保存するローカルディレクトリ。これはまたHugging Face Hub上のリポジトリ名になります。generation_max_length
:評価中に自己回帰的に生成するトークンの最大数。save_steps
:トレーニング中に中間のチェックポイントが保存され、save_steps
ごとに非同期でHubにアップロードされます。eval_steps
:トレーニング中に中間のチェックポイントの評価がeval_steps
ごとに実行されます。report_to
:トレーニングログの保存先。サポートされているプラットフォームは"azure_ml"
、"comet_ml"
、"mlflow"
、"neptune"
、"tensorboard"
、"wandb"
です。お好きなものを選択するか、"tensorboard"
のままにしてHubにログを保存します。
その他のトレーニング引数の詳細については、Seq2SeqTrainingArgumentsのドキュメントを参照してください。
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
output_dir="./whisper-small-hi", # 好きなリポジトリ名に変更してください
per_device_train_batch_size=16,
gradient_accumulation_steps=1, # バッチサイズを2倍減らすごとに2倍増やしてください
learning_rate=1e-5,
warmup_steps=500,
max_steps=4000,
gradient_checkpointing=True,
fp16=True,
evaluation_strategy="steps",
per_device_eval_batch_size=8,
predict_with_generate=True,
generation_max_length=225,
save_steps=1000,
eval_steps=1000,
logging_steps=25,
report_to=["tensorboard"],
load_best_model_at_end=True,
metric_for_best_model="wer",
greater_is_better=False,
push_to_hub=True,
)
注意:モデルのチェックポイントをHubにアップロードしたくない場合は、push_to_hub=False
と設定します。
トレーニング引数を🤗 Trainerにモデル、データセット、データコレータ、およびcompute_metrics
関数と共に渡すことができます:
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=common_voice["train"],
eval_dataset=common_voice["test"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
これで、トレーニングを開始する準備が整いました!
トレーニング
トレーニングを開始するには、単純に次のコードを実行してください:
trainer.train()
トレーニングには、GPUの性能やGoogle Colabに割り当てられたGPUによって約5〜10時間かかります。GPUによっては、トレーニングを開始するとCUDAの"out-of-memory"
エラーが発生する可能性があります。その場合は、per_device_train_batch_size
を2の倍数で徐々に減らし、gradient_accumulation_steps
を使用して対処することができます。
出力結果:
私たちの最良の WER は32.0%です。8時間のトレーニングデータでは悪くありません!大きな問題は、これが他のASRシステムと比較してどのようなものかです。そのためには、hf-speech-bench
というリーダーボードを表示することができます。このリーダーボードは、言語とデータセットによってモデルを分類し、その後、WERに基づいてランキングします。
私たちのファインチューニングされたモデルは、Whisperのゼロショットパフォーマンスを大幅に改善しており、Whisperの強力な転移学習能力を示しています。
トレーニング結果をハブにプッシュすると、チェックポイントをリーダーボードに自動的に送信することができます。適切なキーワード引数(kwargs)を設定するだけです。これらの値をデータセット、言語、およびモデル名に合わせて変更できます:
kwargs = {
"dataset_tags": "mozilla-foundation/common_voice_11_0",
"dataset": "Common Voice 11.0", # トレーニングデータセットの「わかりやすい」名前
"dataset_args": "config: hi, split: test",
"language": "hi",
"model_name": "Whisper Small Hi - Sanchit Gandhi", # モデルの「わかりやすい」名前
"finetuned_from": "openai/whisper-small",
"tasks": "automatic-speech-recognition",
"tags": "hf-asr-leaderboard",
}
トレーニング結果をハブにアップロードすることができます。そのためには、push_to_hub
コマンドを実行します:
trainer.push_to_hub(**kwargs)
ハブ上のリンクを使用して、誰とでもこのモデルを共有できます。また、次の識別子を使用してロードすることもできます:"your-username/the-name-you-picked"
、例えば:
from transformers import WhisperForConditionalGeneration, WhisperProcessor
model = WhisperForConditionalGeneration.from_pretrained("sanchit-gandhi/whisper-small-hi")
processor = WhisperProcessor.from_pretrained("sanchit-gandhi/whisper-small-hi")
ファインチューニングされたモデルは、Common Voiceのヒンディー語のテストデータで満足のいく結果をもたらしますが、最適とは言えません。このノートブックの目的は、事前学習済みのWhisperチェックポイントをどの多言語ASRデータセットでもファインチューニングする方法を示すことです。学習率やドロップアウトなどのトレーニングハイパーパラメータを最適化し、より大きな事前学習済みチェックポイント(VoAGI
またはlarge
)を使用することで、結果をより改善できる可能性があります。
デモの作成
モデルをファインチューニングしたので、ASRの機能を示すデモを作成することができます!🤗 Transformersのpipeline
を使用します。これにより、オーディオ入力の前処理からモデルの予測をデコードするまで、ASRパイプライン全体を簡単に処理できます。対話型デモをGradioで作成します。Gradioは、機械学習デモを作成する最も簡単な方法の一つです。Gradioを使用すると、わずか数分でデモを作成することができます!
以下の例を実行すると、Gradioデモが生成されます。このデモでは、コンピュータのマイクを介して音声を録音し、ファインチューニングされたWhisperモデルに入力して対応するテキストを転写します:
from transformers import pipeline
import gradio as gr
pipe = pipeline(model="sanchit-gandhi/whisper-small-hi") # "your-username/the-name-you-picked"に変更
def transcribe(audio):
text = pipe(audio)["text"]
return text
iface = gr.Interface(
fn=transcribe,
inputs=gr.Audio(source="microphone", type="filepath"),
outputs="text",
title="Whisper Small Hindi",
description="リアルタイムデモ:ファインチューニングされたWhisper Smallモデルを使用したヒンディー語音声認識。",
)
iface.launch()
締めくくり
このブログでは、🤗 Datasets、Transformers、およびHugging Face Hubを使用して、多言語ASRのためにWhisperをファインチューニングする手順を詳しく説明しました。自分自身でファインチューニングを試してみたい場合は、Google Colabを参照してください。英語や多言語のASRのために他のTransformersモデルをファインチューニングすることに興味がある場合は、examples/pytorch/speech-recognitionの例のスクリプトをチェックしてください。
We will continue to update VoAGI; if you have any questions or suggestions, please contact us!
Was this article helpful?
93 out of 132 found this helpful
Related articles