🤗 Transformersを使用して、低リソースASRのためにXLSR-Wav2Vec2を微調整する
使用するのは、🤗 TransformersXLSR-Wav2Vec2を微調整して、低リソースASRに活用します
新着(11/2021):このブログ投稿は、XLSRの後継であるXLS-Rを紹介するように更新されました。
Wav2Vec2は、自動音声認識(ASR)のための事前学習モデルであり、Alexei Baevski、Michael Auli、Alex Conneauによって2020年9月にリリースされました。Wav2Vec2の優れた性能が、ASRの最も人気のある英語データセットであるLibriSpeechで示されるとすぐに、Facebook AIはWav2Vec2の多言語版であるXLSRを発表しました。XLSRはクロスリンガル音声表現を意味し、モデルが複数の言語で有用な音声表現を学習できる能力を指します。
XLSRの後継であるXLS-R(「音声用のXLM-R」という意味)は、Arun Babu、Changhan Wang、Andros Tjandraなどによって2021年11月にリリースされました。XLS-Rは、自己教師付き事前学習のために128の言語で約500,000時間のオーディオデータを使用し、パラメータ数が30億から200億までのサイズで提供されています。事前学習済みのチェックポイントは、🤗 Hubで見つけることができます:
- IPUを使用したHugging Face Transformersの始め方と最適化について
- スノーボールファイト ☃️をご紹介しますこれは私たちの最初のML-Agents環境です
- スクラッチからCodeParrot 🦜をトレーニングする
- Wav2Vec2-XLS-R-300M
- Wav2Vec2-XLS-R-1B
- Wav2Vec2-XLS-R-2B
BERTのマスクされた言語モデリング目的と同様に、XLS-Rは自己教師付き事前学習中に特徴ベクトルをランダムにマスクしてからトランスフォーマーネットワークに渡すことで、文脈化された音声表現を学習します(左側の図)。
ファインチューニングでは、事前学習済みネットワークの上に単一の線形層が追加され、音声認識、音声翻訳、音声分類などのラベル付きデータでモデルをトレーニングします(右側の図)。
XLS-Rは、公式論文のTable 3-6、Table 7-10、Table 11-12で、以前の最先端の結果に比べて音声認識、音声翻訳、話者/言語識別の両方で印象的な改善を示しています。
セットアップ
このブログでは、XLS-R(具体的には事前学習済みチェックポイントWav2Vec2-XLS-R-300M)をASRのためにファインチューニングする方法について詳しく説明します。
デモンストレーションの目的で、我々は低リソースなASRデータセットのCommon Voiceでモデルをファインチューニングします。このデータセットには検証済みのトレーニングデータが約4時間しか含まれていません。
XLS-Rは、音声認識や手書き認識など、シーケンス間の問題のためにニューラルネットワークをトレーニングするために使用されるアルゴリズムであるConnectionist Temporal Classification(CTC)を使用してファインチューニングされます。
Awni Hannunによるよく書かれたブログ記事「Sequence Modeling with CTC(2017)」をお勧めします。
始める前に、datasets
とtransformers
をインストールしてください。また、オーディオファイルを読み込むためにtorchaudio
と、単語エラーレート(WER)メトリックを使用してファインチューニングされたモデルを評価するためにjiwer
も必要です。
!pip install datasets==1.18.3
!pip install transformers==4.11.3
!pip install huggingface_hub==0.1
!pip install torchaudio
!pip install librosa
!pip install jiwer
トレーニング中にトレーニングチェックポイントをHugging Face Hubに直接アップロードすることを強くお勧めします。Hugging Face Hubには統合されたバージョン管理があるため、トレーニング中にモデルチェックポイントが失われることはありません。
これを行うには、Hugging Faceのウェブサイトから認証トークンを保存する必要があります(まだ登録していない場合はここでサインアップしてください)
from huggingface_hub import notebook_login
notebook_login()
出力結果:
ログインが成功しました
トークンは /root/.huggingface/token に保存されました
モデルのチェックポイントをアップロードするには、Git-LFSをインストールする必要があります:
apt install git-lfs
1 {}^1 1 論文では、モデルの評価には音素エラーレート(PER)が使用されましたが、ASRでは最も一般的な評価指標は単語エラーレート(WER)です。このノートブックをできるだけ一般的なものにするため、WERを使用してモデルを評価することにしました。
データ、トークナイザ、特徴抽出器の準備
ASRモデルは音声をテキストに変換するため、音声信号をモデルの入力形式(例:特徴ベクトル)に変換する特徴抽出器と、モデルの出力形式をテキストに変換するトークナイザの両方が必要です。
🤗 Transformersでは、XLS-RモデルにはWav2Vec2CTCTokenizerと呼ばれるトークナイザと、Wav2Vec2FeatureExtractorと呼ばれる特徴抽出器が付属しています。
まず、トークナイザを作成して、予測された出力クラスを出力の転写にデコードします。
Wav2Vec2CTCTokenizer
の作成
事前学習済みのXLS-Rモデルは、上記の図で示されるように音声信号をコンテキスト表現のシーケンスにマッピングします。ただし、音声認識では、このコンテキスト表現のシーケンスを対応する転写にマッピングする必要があります。つまり、変換器ブロックの上に線形層を追加する必要があります(上図の黄色で示されています)。この線形層は、各コンテキスト表現をトークンクラスに分類するために使用されます。これは、事前トレーニング時にBERTの埋め込みの上に線形層が追加されるのと同様です(以下のブログ記事の「BERT」セクションと比較してください)。この線形層の出力サイズは、語彙中のトークン数に対応し、XLS-Rの事前学習タスクには依存せず、ファインチューニングに使用されるラベル付きデータセットのみに依存します。したがって、最初のステップでは、Common Voiceの選択したデータセットを調べ、転写に基づいて語彙を定義します。
まず、Common Voiceの公式ウェブサイトに移動し、XLS-Rをファインチューニングする言語を選択します。このノートブックでは、トルコ語を使用します。
各言語固有のデータセットには、選択した言語に対応する言語コードがあります。Common Voiceでは、「Version」というフィールドを探します。言語コードは、アンダースコアの前の接頭辞に対応します。たとえば、トルコ語の場合、言語コードは"tr"
です。
素晴らしいですね、では、🤗 DatasetsのシンプルなAPIを使用してデータをダウンロードします。データセット名は"common_voice"
で、構成名は言語コードに対応します。この場合は"tr"
です。
Common Voiceには、invalidated
というさまざまなスプリットがあります。これは、「十分にクリーンでない」と評価されなかったデータを指します。このノートブックでは、"train"
、"validation"
、"test"
のスプリットのみを使用します。
トルコのデータセットは非常に小さいため、バリデーションとトレーニングデータを統合してトレーニングデータセットとし、テストデータのみをバリデーションに使用します。
from datasets import load_dataset, load_metric, Audio
common_voice_train = load_dataset("common_voice", "tr", split="train+validation")
common_voice_test = load_dataset("common_voice", "tr", split="test")
多くのASRデータセットは、各オーディオ配列'audio'
とファイル'path'
に対してターゲットテキスト'sentence'
のみを提供します。Common Voiceは、オーディオファイルに関するさまざまな情報('accent'
など)を提供します。できるだけ一般的なノートブックにするため、ファインチューニングには転写されたテキストのみを考慮します。
common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
データセットのランダムなサンプルを表示するための短い関数を作成し、いくつかの回数実行して、転写の感触を得ましょう。
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML
def show_random_elements(dataset, num_examples=10):
assert num_examples <= len(dataset), "データセットの要素数より多くの要素を選ぶことはできません。"
picks = []
for _ in range(num_examples):
pick = random.randint(0, len(dataset)-1)
while pick in picks:
pick = random.randint(0, len(dataset)-1)
picks.append(pick)
df = pd.DataFrame(dataset[picks])
display(HTML(df.to_html()))
出力結果を表示:
よし!転写はかなりきれいに見えます。転写された文を翻訳した結果、言語はノイズの多い対話よりも書きことに対応しているようです。これは、Common Voiceがクラウドソーシングされた読み上げ音声コーパスであるためです。
転写には,.?!;:
などの特殊文字が含まれていることがわかります。これらの特殊文字は、特徴的な音声単位に対応していないため、言語モデルがないと特殊文字を音声チャンクに分類することはより困難です。たとえば、文字"s"
には比較的明確な音がありますが、特殊文字"."
にはありません。また、音声信号の意味を理解するためには、通常、特殊文字を転写に含める必要はありません。
単語の意味に貢献せず、音声音響で正確に表現することができないすべての文字を単純に削除し、テキストを正規化しましょう。
import re
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'
def remove_special_characters(batch):
batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
return batch
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)
処理されたテキストラベルを再度確認しましょう。
show_random_elements(common_voice_train.remove_columns(["path","audio"]))
出力結果を表示:
いい感じですね。転写からほとんどの特殊文字が削除され、すべて小文字に正規化されました。
前処理を最終的に行う前に、対象言語のネイティブスピーカーに相談してテキストをさらに簡素化できるかどうかを確認することは常に有利です。このブログ投稿では、Merveさんが手早く見てくれて、「帽子のついた」文字(â
など)はトルコ語ではもはや使用されていないため、「帽子のない」等価物(たとえばa
)で置き換えることができると指摘してくれました。
これは、"yargı sistemi hâlâ sağlıksız"
のような文を"yargı sistemi hala sağlıksız"
に置き換えるべきことを意味します。
さらにテキストラベルを簡素化するための別の短いマッピング関数を作成しましょう。テキストラベルが簡単であればあるほど、モデルがそれらのラベルを予測するのが容易になります。
def replace_hatted_characters(batch):
batch["sentence"] = re.sub('[â]', 'a', batch["sentence"])
batch["sentence"] = re.sub('[î]', 'i', batch["sentence"])
batch["sentence"] = re.sub('[ô]', 'o', batch["sentence"])
batch["sentence"] = re.sub('[û]', 'u', batch["sentence"])
return batch
common_voice_train = common_voice_train.map(replace_hatted_characters)
common_voice_test = common_voice_test.map(replace_hatted_characters)
CTCでは、音声チャンクを文字に分類することが一般的ですので、ここでも同じことを行います。トレーニングデータとテストデータのすべての異なる文字を抽出し、その文字の集合からボキャブラリーを構築します。
すべての転写を1つの長い転写に連結し、その文字列を文字のセットに変換するマッピング関数を作成します。マッピング関数が一度にすべての転写にアクセスできるように、map(...)
関数にbatched=True
引数を渡すことが重要です。
def extract_all_chars(batch):
all_text = " ".join(batch["sentence"])
vocab = list(set(all_text))
return {"vocab": [vocab], "all_text": [all_text]}
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
今、我々はトレーニングデータセットとテストデータセットのすべての異なる文字の和集合を作成し、その結果のリストを列挙された辞書に変換します。
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
出力結果を表示:
{
' ': 0,
'a': 1,
'b': 2,
'c': 3,
'd': 4,
'e': 5,
'f': 6,
'g': 7,
'h': 8,
'i': 9,
'j': 10,
'k': 11,
'l': 12,
'm': 13,
'n': 14,
'o': 15,
'p': 16,
'q': 17,
'r': 18,
's': 19,
't': 20,
'u': 21,
'v': 22,
'w': 23,
'x': 24,
'y': 25,
'z': 26,
'ç': 27,
'ë': 28,
'ö': 29,
'ü': 30,
'ğ': 31,
'ı': 32,
'ş': 33,
'̇': 34
}
素晴らしい、このデータセットにはアルファベットのすべての文字が含まれていることがわかります(それは実際には驚くべきことではありません)また、特殊文字""
と'
も抽出しました。これらの特殊文字を除外しなかった理由は次の通りです:
モデルは単語が終わったときを予測することを学習しなければならず、そうでなければモデルの予測は常に文字のシーケンスになり、単語を区切ることができなくなります。
モデルを訓練する前に前処理は非常に重要なステップであることを常に念頭に置くべきです。例えば、データを正規化するのを忘れたためにa
とA
を区別したくないとします。 a
とA
の違いは、文字の「音」には全く依存せず、むしろ文法的なルールに依存します。例えば、文の先頭に大文字の文字を使用します。したがって、大文字と小文字の文字の違いを取り除くことで、モデルが音声を転写することを学ぶのがより容易になります。
「 」に独自のトークンクラスを持たせるために、より目立つ文字「|」を与えます。さらに、Common Voiceのトレーニングセットで遭遇しなかった文字に対処できるように、「unknown」トークンも追加します。
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
最後に、CTCの「ブランクトークン」に対応するパディングトークンも追加します。 「ブランクトークン」はCTCアルゴリズムの中核です。詳細については、こちらの「Alignment」セクションをご覧ください。
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
素晴らしい、これでボキャブラリーが完成し、39のトークンで構成されることがわかります。このため、事前学習済みのXLS-Rチェックポイントの上に追加する線形層の出力次元は39になります。
それでは、ボキャブラリーをjsonファイルとして保存しましょう。
import json
with open('vocab.json', 'w') as vocab_file:
json.dump(vocab_dict, vocab_file)
最後のステップとして、jsonファイルを使用してWav2Vec2CTCTokenizer
クラスのインスタンスにボキャブラリーをロードします。
from transformers import Wav2Vec2CTCTokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
このノートブックのファインチューニングモデルと一緒に作成されたトークナイザーを再利用したい場合は、tokenizer
をHugging Face Hubにアップロードすることを強くお勧めします。ファイルをアップロードするリポジトリの名前を"wav2vec2-large-xlsr-turkish-demo-colab"
としましょう。
repo_name = "wav2vec2-large-xls-r-300m-tr-colab"
そして、トークナイザーを🤗 Hubにアップロードします。
tokenizer.push_to_hub(repo_name)
素晴らしいですね、作成したリポジトリはhttps://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-tr-colab
の下に表示されます。
Wav2Vec2FeatureExtractor
の作成
音声は連続的な信号であり、コンピュータが処理するためにはまず離散化する必要があります。これは通常サンプリングと呼ばれます。ここでサンプリングレートは重要な役割を果たし、音声信号の秒間のデータポイントの数を定義します。したがって、サンプリングレートが高いほど、実際の音声信号により近い近似が得られますが、秒間の値も増える必要があります。
事前学習済みのチェックポイントは、入力データがモデルの事前学習に使用されたデータとほぼ同じ分布からサンプリングされたものであることを期待しています。サンプリングレートが異なる2つの異なる速度でサンプリングされた同じ音声信号は、非常に異なる分布を持ちます。例えば、サンプリングレートを倍にすると、データポイントが2倍になります。したがって、音声認識モデルの事前学習済みチェックポイントをファインチューニングする前に、モデルの事前学習に使用されたデータのサンプリングレートが、モデルのファインチューニングに使用されるデータのサンプリングレートと一致していることを確認することが重要です。
XLS-Rは、Babel、Multilingual LibriSpeech(MLS)、Common Voice、VoxPopuli、およびVoxLingua107のオーディオデータを16kHzのサンプリングレートで事前学習しています。元の形式のCommon Voiceは、サンプリングレートが48kHzですので、以下ではファインチューニングデータを16kHzにダウンサンプリングする必要があります。
Wav2Vec2FeatureExtractor
オブジェクトのインスタンス化には、次のパラメータが必要です:
feature_size
:音声モデルは、入力として特徴ベクトルのシーケンスを取ります。このシーケンスの長さは明らかに異なりますが、特徴のサイズは同じである必要があります。Wav2Vec2の場合、特徴のサイズは1です。なぜなら、モデルは生の音声信号を2の乗数の長さで訓練されたからです。sampling_rate
:モデルが訓練されたサンプリングレートです。padding_value
:バッチ推論では、短い入力は特定の値でパディングする必要があります。do_normalize
:入力をゼロ平均単位分散正規化するかどうか。通常、音声モデルは入力を正規化すると性能が向上します。return_attention_mask
:モデルがバッチ推論でattention_mask
を使用するかどうか。一般的に、XLS-Rモデルのチェックポイントではattention_mask
を必ず使用する必要があります。
from transformers import Wav2Vec2FeatureExtractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
素晴らしいですね、XLS-Rの特徴抽出パイプラインは完全に定義されました!
ユーザーフレンドリーさを向上させるために、特徴抽出器とトークナイザーは単一のWav2Vec2Processor
クラスにラップされていますので、model
とprocessor
オブジェクトのみが必要です。
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
次に、データセットを準備します。
データの前処理
これまでは、音声信号の実際の値ではなく、転写のみを見てきました。私たちのデータセットには、sentence
以外にもpath
とaudio
という2つの列名が含まれています。path
は音声ファイルの絶対パスを示しています。見てみましょう。
common_voice_train[0]["path"]
XLS-Rは、1次元の16kHzの配列形式での入力を期待しています。つまり、オーディオファイルをロードしてリサンプリングする必要があります。
幸いにも、datasets
はこれを自動的に行います。他の列audio
を呼び出すことで試してみましょう。
common_voice_train[0]["audio"]
{'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-8.8930130e-05, -3.8027763e-05, -2.9146671e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
'sampling_rate': 48000}
素晴らしいですね、オーディオファイルが自動的にロードされました。これは、datasets == 1.18.3
で導入された新機能である"Audio"
によって可能になりました。この機能は呼び出し時にオーディオファイルを動的にロードしてリサンプリングします。
上記の例では、オーディオデータが48kHzのサンプリングレートでロードされていますが、モデルでは16kHzが期待されています。正しいサンプリングレートでオーディオ機能を設定するには、cast_column
を使用します。
common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))
では、再び"audio"
を見てみましょう。
common_voice_train[0]["audio"]
{'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-7.4556941e-05, -1.4621433e-05, -5.7861507e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
'sampling_rate': 16000}
うまくいったようですね!データが正しくロードされ、リサンプリングされたようです。
データセットをよりよく理解し、オーディオが正しくロードされたか確認するために、いくつかのオーディオファイルを聴いてみましょう。
import IPython.display as ipd
import numpy as np
import random
rand_int = random.randint(0, len(common_voice_train)-1)
print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)
出力結果:
sunulan bütün teklifler i̇ngilizce idi
データが正しくロードされ、リサンプリングされたようです。
話者は話す速度やアクセント、背景環境などに応じて変化しますが、全体的にはクリアな音声が収録されているようです。これは、クラウドソーシングされた読み上げ音声コーパスから期待されるものです。
最後に、データが正しく準備されているかを確認するために、音声入力の形状、転写テキスト、および対応するサンプリングレートを表示してみましょう。
rand_int = random.randint(0, len(common_voice_train)-1)
print("目標テキスト:", common_voice_train[rand_int]["sentence"])
print("入力配列の形状:", common_voice_train[rand_int]["audio"]["array"].shape)
print("サンプリングレート:", common_voice_train[rand_int]["audio"]["sampling_rate"])
出力結果:
目標テキスト: makedonya bu yıl otuz adet tyetmiş iki tankı aldı
入力配列の形状: (71040,)
サンプリングレート: 16000
素晴らしい!すべてが正常に見えます。データは1次元の配列であり、サンプリングレートは常に16kHzに対応しており、目標テキストは正規化されています。
最後に、Wav2Vec2ForCTC
のトレーニングに必要な形式にデータを処理するためにWav2Vec2Processor
を活用しましょう。これを行うには、Datasetのmap(...)
関数を使用します。
最初に、単純に batch["audio"]
を呼び出すことでオーディオデータをロードしてリサンプルします。次に、ロードされたオーディオファイルから input_values
を抽出します。この場合、Wav2Vec2Processor
はデータを単に正規化するだけです。しかし、他の音声モデルの場合、このステップには Log-Mel 特徴量の抽出など、より複雑な特徴量の抽出が含まれる場合もあります。三番目に、トランスクリプションをラベルIDにエンコードします。
注意:このマッピング関数は、Wav2Vec2Processor
クラスの使用方法の良い例です。通常のコンテキストでは、processor(...)
を呼び出すと、Wav2Vec2FeatureExtractor
の呼び出しメソッドにリダイレクトされます。ただし、プロセッサを as_target_processor
コンテキストにラップすると、同じメソッドが Wav2Vec2CTCTokenizer
の呼び出しメソッドにリダイレクトされます。詳細については、ドキュメントをご覧ください。
def prepare_dataset(batch):
audio = batch["audio"]
# バッチ出力は「アンバッチ」されます
batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
batch["input_length"] = len(batch["input_values"])
with processor.as_target_processor():
batch["labels"] = processor(batch["sentence"]).input_ids
return batch
データの準備関数をすべての例に適用しましょう。
common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)
注意:現在、datasets
はオーディオのロードとリサンプリングに torchaudio
と librosa
を使用しています。独自のカスタマイズされたデータのロード/サンプリングを実装したい場合は、単に "path"
列を使用し、"audio"
列を無視してください。
長い入力シーケンスは多くのメモリを必要とします。XLS-R は self-attention
に基づいています。長い入力シーケンスの場合、入力長に対してメモリ要件が二次的にスケーリングします(redditの投稿を参照)。このデモが「メモリ不足」エラーでクラッシュする場合は、以下の行のコメント解除して、トレーニングにおいて5秒より長いシーケンスをすべてフィルタリングするようにしてください。
#max_input_length_in_sec = 5.0
#common_voice_train = common_voice_train.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])
素晴らしい、これでトレーニングを開始する準備が整いました!
トレーニング
データが処理されたので、トレーニングパイプラインの設定を開始する準備が整いました。トレーニングには 🤗 の Trainer を利用します。以下の手順が必要です。
-
データコレータを定義します。通常の NLP モデルとは異なり、XLS-R は入力長が出力長よりもはるかに大きいです。例えば、入力長が50000のサンプルの出力長は最大100以下です。大きな入力サイズでは、トレーニングバッチを動的にパディングすることが効率的です。つまり、すべてのトレーニングサンプルは、バッチ内の最長サンプルにのみパディングされ、全体の最長サンプルにはパディングされません。そのため、XLS-R のファインチューニングには特殊なパディングデータコレータが必要です。以下で定義します。
-
評価指標。トレーニング中、モデルは単語エラーレートで評価されるべきです。対応する
compute_metrics
関数を定義する必要があります。 -
学習済みのチェックポイントをロードします。学習済みのチェックポイントをロードし、正しくトレーニングするために構成する必要があります。
-
トレーニングの設定を定義します。
モデルをファインチューニングした後、テストデータで正しく評価し、音声の正確な転写を学習したことを確認します。
トレーナーの設定
まず、データコレータを定義しましょう。データコレータのコードは、この例からコピーされました。
詳細には立ち入らずに、一般的なデータコレータとは異なり、このデータコレータは input_values
と labels
を異なる方法で処理し、それぞれに異なるパディング関数を適用します(再び XLS-R プロセッサのコンテキストマネージャを利用しています)。これは、音声の入力と出力が異なるモダリティであるため、同じパディング関数で処理すべきではないためです。一般的なデータコレータと同様に、ラベルのパディングトークンには -100
を使用してラベルの損失計算時に考慮されないようにします。
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
@dataclass
class DataCollatorCTCWithPadding:
"""
入力を動的にパディングするデータコレータ。
Args:
processor (:class:`~transformers.Wav2Vec2Processor`)
データの処理に使用するプロセッサ。
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
返されるシーケンスをパディングする戦略を選択します(モデルのパディング側とパディングインデックスに従って):
* :obj:`True` または :obj:`'longest'`:バッチ内の最長シーケンスにパディングします(単一のシーケンスの場合はパディングなし)。
* :obj:`'max_length'`:引数 :obj:`max_length` で指定された最大長にパディングします。引数が指定されていない場合は、モデルの最大入力長にパディングします。
* :obj:`False` または :obj:`'do_not_pad'`(デフォルト):パディングなし(つまり、異なる長さのシーケンスを含むバッチを出力できます)。
"""
processor: Wav2Vec2Processor
padding: Union[bool, str] = True
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# 入力とラベルを異なる長さに分割し、異なるパディング方法が必要です
input_features = [{"input_values": feature["input_values"]} for feature in features]
label_features = [{"input_ids": feature["labels"]} for feature in features]
batch = self.processor.pad(
input_features,
padding=self.padding,
return_tensors="pt",
)
with self.processor.as_target_processor():
labels_batch = self.processor.pad(
label_features,
padding=self.padding,
return_tensors="pt",
)
# ロスを正しく無視するためにパディングを -100 に置換します
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
batch["labels"] = labels
return batch
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
次に、評価メトリックを定義します。先述の通り、ASRで主要なメトリックは単語エラーレート(WER)ですので、このノートブックでも使用します。
wer_metric = load_metric("wer")
モデルは、ログオッズベクトルのシーケンスを返します: y 1 , … , y m \mathbf{y}_1, \ldots, \mathbf{y}_m y 1 , … , y m ただし y 1 = f θ ( x 1 , … , x n ) [ 0 ] \mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0] y 1 = f θ ( x 1 , … , x n ) [ 0 ] および n > > m n >> m n > > m です。
ログオッズベクトル y 1 \mathbf{y}_1 y 1 は、前述で定義した語彙の各単語の対数オッズを含んでいるため、len ( y i ) = \text{len}(\mathbf{y}_i) = len ( y i ) = config.vocab_size
です。モデルの最も確からしい予測に興味があるため、ログオッズの argmax(...)
を取ります。また、エンコードされたラベルを元の文字列に戻すために、-100
を pad_token_id
で置き換え、CTCスタイル 1 {}^1 1 で連続するトークンを同じトークンとしてグループ化しないようにして、ID をデコードします。
def compute_metrics(pred):
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)
pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
pred_str = processor.batch_decode(pred_ids)
# メトリックを計算する際にトークンをグループ化したくないので
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
wer = wer_metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
それでは、Wav2Vec2-XLS-R-300M の事前学習済みチェックポイントを読み込むことができます。トークナイザの pad_token_id
はモデルの pad_token_id
または Wav2Vec2ForCTC
の場合は CTC のブランクトークン 2 {}^2 2 を定義する必要があります。GPUメモリを節約するために、PyTorch の勾配チェックポイントを有効にし、ロスの削減を ” mean ” に設定します。
データセットが非常に小さい(トレーニングデータが約6時間分)であり、Common Voiceはかなりノイズが多いため、Facebookのwav2vec2-xls-r-300mチェックポイントの微調整にはいくつかのハイパーパラメータの調整が必要なようです。そのため、dropout、SpecAugmentのマスキングドロップアウト率、レイヤードロップアウト、および学習率のさまざまな値を試して、トレーニングが十分に安定するまで調整しました。
注意: このノートブックを使用してCommon Voiceの別の言語でXLS-Rをトレーニングする場合、これらのハイパーパラメータ設定はうまく機能しないかもしれません。使用ケースに応じて適応してください。
from transformers import Wav2Vec2ForCTC
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-xls-r-300m",
attention_dropout=0.0,
hidden_dropout=0.0,
feat_proj_dropout=0.0,
mask_time_prob=0.05,
layerdrop=0.0,
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
vocab_size=len(processor.tokenizer),
)
XLS-Rの最初のコンポーネントは、生の音声信号から音響的に意味のあるが文脈に依存しないフィーチャーを抽出するために使用されるCNNレイヤーのスタックで構成されています。このモデルのこの部分は事前トレーニング中に十分にトレーニングされており、転移学習の必要はないとされています。したがって、フィーチャー抽出部のすべてのパラメータのrequires_grad
をFalse
に設定できます。
model.freeze_feature_extractor()
最後のステップでは、トレーニングに関連するすべてのパラメータを定義します。いくつかのパラメータの詳細について説明します:
group_by_length
は、入力の長さが似ているトレーニングサンプルを1つのバッチにグループ化することで、トレーニングをより効率的に行います。これにより、モデルを通過する無駄なパディングトークンの総数を大幅に削減することで、トレーニング時間を大幅に短縮できます。learning_rate
およびweight_decay
は、微調整が安定するまで経験的に調整されました。これらのパラメータはCommon Voiceデータセットに強く依存しており、他の音声データセットには最適ではない可能性があります。
その他のパラメータの詳細については、ドキュメントを参照してください。
トレーニング中に、チェックポイントは400ステップごとに非同期でHubにアップロードされます。これにより、モデルがまだトレーニング中でもデモウィジェットで遊ぶことができます。
注意: モデルのチェックポイントをHubにアップロードしたくない場合は、push_to_hub=False
に設定します。
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir=repo_name,
group_by_length=True,
per_device_train_batch_size=16,
gradient_accumulation_steps=2,
evaluation_strategy="steps",
num_train_epochs=30,
gradient_checkpointing=True,
fp16=True,
save_steps=400,
eval_steps=400,
logging_steps=400,
learning_rate=3e-4,
warmup_steps=500,
save_total_limit=2,
push_to_hub=True,
)
これで、すべてのインスタンスをTrainerに渡してトレーニングを開始する準備ができました!
from transformers import Trainer
trainer = Trainer(
model=model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=common_voice_train,
eval_dataset=common_voice_test,
tokenizer=processor.feature_extractor,
)
1 {}^1 1 モデルが話者の速度に依存しないようにするために、CTCでは、連続する同一のトークンは単一のトークンとしてグループ化されます。ただし、デコード時にはエンコードされたラベルをグループ化してはいけません。なぜなら、それらはモデルの予測トークンに対応していないからです。そのため、group_tokens=False
パラメータを渡す必要があります。このパラメータを渡さない場合、"hello"
のような単語が正しくエンコードされ、"helo"
としてデコードされてしまいます。 2 {}^2 2 ブランクトークンは、モデルが2つのlの間にブランクトークンを挿入することで、"hello"
などの単語を予測できるようにします。モデルによるCTC準拠の"hello"
の予測は[PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD]
となります。
トレーニング
トレーニングには、このノートブックに割り当てられたGPUに応じて複数時間かかる場合があります。トレーニングされたモデルは、トルコ語のCommon Voiceのテストデータにおいて、ある程度満足のいく結果を出しますが、最適に調整されたモデルではありません。このノートブックの目的は、ASRデータセットでXLS-R XLSR-Wav2Vec2を微調整する方法を示すことです。
Google Colabに割り当てられたGPUによっては、ここで「メモリ不足」のエラーが表示される可能性があります。その場合は、per_device_train_batch_size
を8またはそれ以下に減らし、gradient_accumulation
を増やすことが最善です。
trainer.train()
出力結果:
トレーニングの損失と検証のWERがうまく減少しています。
トレーニングの結果をHubにアップロードすることができます。次の命令を実行してください:
trainer.push_to_hub()
これで、このモデルをあなたの友人、家族、お気に入りのペットと共有することができます。彼らは、”your-username/the-name-you-picked”という識別子でそれをロードすることができます。たとえば:
from transformers import AutoModelForCTC, Wav2Vec2Processor
model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
XLS-Rを微調整する他の例については、公式の🤗 Transformersの例をご覧ください。
評価
最終的なチェックとして、モデルをロードし、トルコ語の音声を正確に変換できたかを確認しましょう。
まず、事前学習済みのチェックポイントをロードしましょう。
model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)
次に、テストセットの最初の例を取り出し、モデルを通して実行し、ロジットのargmax(...)
を取得して予測されたトークンIDを取得します。
input_dict = processor(common_voice_test[0]["input_values"], return_tensors="pt", padding=True)
logits = model(input_dict.input_values.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)[0]
この関数にsampling_rate
引数を渡すことを強くお勧めします。それを行わないと、デバッグが困難なエラーが発生する可能性があります。
common_voice_test
をかなり修正したため、データセットインスタンスには元の文のラベルが含まれていないことに注意してください。したがって、最初の例のラベルを取得するために元のデータセットを再利用します。
common_voice_test_transcription = load_dataset("common_voice", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test")
最後に、例をデコードできます。
print("Prediction:")
print(processor.decode(pred_ids))
print("\nReference:")
print(common_voice_test_transcription[0]["sentence"].lower())
出力結果:
すばらしい!予測からは、トランスクリプションが確かに認識できますが、まだ完璧ではありません。モデルをもう少し長くトレーニングし、データの前処理にもっと時間をかけ、特にデコーディングに言語モデルを使用すると、モデルの全体的なパフォーマンスが向上するでしょう。
低リソース言語のデモモデルとしては、結果はかなり受け入れられるものです🤗。
We will continue to update VoAGI; if you have any questions or suggestions, please contact us!
Was this article helpful?
93 out of 132 found this helpful
Related articles