🤗 Transformersを使用して、低リソースASRのためにXLSR-Wav2Vec2を微調整する

使用するのは、🤗 TransformersXLSR-Wav2Vec2を微調整して、低リソースASRに活用します

新着(11/2021):このブログ投稿は、XLSRの後継であるXLS-Rを紹介するように更新されました。

Wav2Vec2は、自動音声認識(ASR)のための事前学習モデルであり、Alexei Baevski、Michael Auli、Alex Conneauによって2020年9月にリリースされました。Wav2Vec2の優れた性能が、ASRの最も人気のある英語データセットであるLibriSpeechで示されるとすぐに、Facebook AIはWav2Vec2の多言語版であるXLSRを発表しました。XLSRはクロスリンガル音声表現を意味し、モデルが複数の言語で有用な音声表現を学習できる能力を指します。

XLSRの後継であるXLS-R(「音声用のXLM-R」という意味)は、Arun Babu、Changhan Wang、Andros Tjandraなどによって2021年11月にリリースされました。XLS-Rは、自己教師付き事前学習のために128の言語で約500,000時間のオーディオデータを使用し、パラメータ数が30億から200億までのサイズで提供されています。事前学習済みのチェックポイントは、🤗 Hubで見つけることができます:

  • Wav2Vec2-XLS-R-300M
  • Wav2Vec2-XLS-R-1B
  • Wav2Vec2-XLS-R-2B

BERTのマスクされた言語モデリング目的と同様に、XLS-Rは自己教師付き事前学習中に特徴ベクトルをランダムにマスクしてからトランスフォーマーネットワークに渡すことで、文脈化された音声表現を学習します(左側の図)。

ファインチューニングでは、事前学習済みネットワークの上に単一の線形層が追加され、音声認識、音声翻訳、音声分類などのラベル付きデータでモデルをトレーニングします(右側の図)。

XLS-Rは、公式論文のTable 3-6、Table 7-10、Table 11-12で、以前の最先端の結果に比べて音声認識、音声翻訳、話者/言語識別の両方で印象的な改善を示しています。

セットアップ

このブログでは、XLS-R(具体的には事前学習済みチェックポイントWav2Vec2-XLS-R-300M)をASRのためにファインチューニングする方法について詳しく説明します。

デモンストレーションの目的で、我々は低リソースなASRデータセットのCommon Voiceでモデルをファインチューニングします。このデータセットには検証済みのトレーニングデータが約4時間しか含まれていません。

XLS-Rは、音声認識や手書き認識など、シーケンス間の問題のためにニューラルネットワークをトレーニングするために使用されるアルゴリズムであるConnectionist Temporal Classification(CTC)を使用してファインチューニングされます。

Awni Hannunによるよく書かれたブログ記事「Sequence Modeling with CTC(2017)」をお勧めします。

始める前に、datasetstransformersをインストールしてください。また、オーディオファイルを読み込むためにtorchaudioと、単語エラーレート(WER)メトリックを使用してファインチューニングされたモデルを評価するためにjiwerも必要です。

!pip install datasets==1.18.3
!pip install transformers==4.11.3
!pip install huggingface_hub==0.1
!pip install torchaudio
!pip install librosa
!pip install jiwer

トレーニング中にトレーニングチェックポイントをHugging Face Hubに直接アップロードすることを強くお勧めします。Hugging Face Hubには統合されたバージョン管理があるため、トレーニング中にモデルチェックポイントが失われることはありません。

これを行うには、Hugging Faceのウェブサイトから認証トークンを保存する必要があります(まだ登録していない場合はここでサインアップしてください)

from huggingface_hub import notebook_login

notebook_login()

出力結果:

    ログインが成功しました
    トークンは /root/.huggingface/token に保存されました

モデルのチェックポイントをアップロードするには、Git-LFSをインストールする必要があります:

apt install git-lfs

1 {}^1 1 論文では、モデルの評価には音素エラーレート(PER)が使用されましたが、ASRでは最も一般的な評価指標は単語エラーレート(WER)です。このノートブックをできるだけ一般的なものにするため、WERを使用してモデルを評価することにしました。

データ、トークナイザ、特徴抽出器の準備

ASRモデルは音声をテキストに変換するため、音声信号をモデルの入力形式(例:特徴ベクトル)に変換する特徴抽出器と、モデルの出力形式をテキストに変換するトークナイザの両方が必要です。

🤗 Transformersでは、XLS-RモデルにはWav2Vec2CTCTokenizerと呼ばれるトークナイザと、Wav2Vec2FeatureExtractorと呼ばれる特徴抽出器が付属しています。

まず、トークナイザを作成して、予測された出力クラスを出力の転写にデコードします。

Wav2Vec2CTCTokenizerの作成

事前学習済みのXLS-Rモデルは、上記の図で示されるように音声信号をコンテキスト表現のシーケンスにマッピングします。ただし、音声認識では、このコンテキスト表現のシーケンスを対応する転写にマッピングする必要があります。つまり、変換器ブロックの上に線形層を追加する必要があります(上図の黄色で示されています)。この線形層は、各コンテキスト表現をトークンクラスに分類するために使用されます。これは、事前トレーニング時にBERTの埋め込みの上に線形層が追加されるのと同様です(以下のブログ記事の「BERT」セクションと比較してください)。この線形層の出力サイズは、語彙中のトークン数に対応し、XLS-Rの事前学習タスクには依存せず、ファインチューニングに使用されるラベル付きデータセットのみに依存します。したがって、最初のステップでは、Common Voiceの選択したデータセットを調べ、転写に基づいて語彙を定義します。

まず、Common Voiceの公式ウェブサイトに移動し、XLS-Rをファインチューニングする言語を選択します。このノートブックでは、トルコ語を使用します。

各言語固有のデータセットには、選択した言語に対応する言語コードがあります。Common Voiceでは、「Version」というフィールドを探します。言語コードは、アンダースコアの前の接頭辞に対応します。たとえば、トルコ語の場合、言語コードは"tr"です。

素晴らしいですね、では、🤗 DatasetsのシンプルなAPIを使用してデータをダウンロードします。データセット名は"common_voice"で、構成名は言語コードに対応します。この場合は"tr"です。

Common Voiceには、invalidatedというさまざまなスプリットがあります。これは、「十分にクリーンでない」と評価されなかったデータを指します。このノートブックでは、"train""validation""test"のスプリットのみを使用します。

トルコのデータセットは非常に小さいため、バリデーションとトレーニングデータを統合してトレーニングデータセットとし、テストデータのみをバリデーションに使用します。

from datasets import load_dataset, load_metric, Audio

common_voice_train = load_dataset("common_voice", "tr", split="train+validation")
common_voice_test = load_dataset("common_voice", "tr", split="test")

多くのASRデータセットは、各オーディオ配列'audio'とファイル'path'に対してターゲットテキスト'sentence'のみを提供します。Common Voiceは、オーディオファイルに関するさまざまな情報('accent'など)を提供します。できるだけ一般的なノートブックにするため、ファインチューニングには転写されたテキストのみを考慮します。

common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

データセットのランダムなサンプルを表示するための短い関数を作成し、いくつかの回数実行して、転写の感触を得ましょう。

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "データセットの要素数より多くの要素を選ぶことはできません。"
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

出力結果を表示:

よし!転写はかなりきれいに見えます。転写された文を翻訳した結果、言語はノイズの多い対話よりも書きことに対応しているようです。これは、Common Voiceがクラウドソーシングされた読み上げ音声コーパスであるためです。

転写には,.?!;:などの特殊文字が含まれていることがわかります。これらの特殊文字は、特徴的な音声単位に対応していないため、言語モデルがないと特殊文字を音声チャンクに分類することはより困難です。たとえば、文字"s"には比較的明確な音がありますが、特殊文字"."にはありません。また、音声信号の意味を理解するためには、通常、特殊文字を転写に含める必要はありません。

単語の意味に貢献せず、音声音響で正確に表現することができないすべての文字を単純に削除し、テキストを正規化しましょう。

import re
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'

def remove_special_characters(batch):
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
    return batch

common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

処理されたテキストラベルを再度確認しましょう。

show_random_elements(common_voice_train.remove_columns(["path","audio"]))

出力結果を表示:

いい感じですね。転写からほとんどの特殊文字が削除され、すべて小文字に正規化されました。

前処理を最終的に行う前に、対象言語のネイティブスピーカーに相談してテキストをさらに簡素化できるかどうかを確認することは常に有利です。このブログ投稿では、Merveさんが手早く見てくれて、「帽子のついた」文字(âなど)はトルコ語ではもはや使用されていないため、「帽子のない」等価物(たとえばa)で置き換えることができると指摘してくれました。

これは、"yargı sistemi hâlâ sağlıksız"のような文を"yargı sistemi hala sağlıksız"に置き換えるべきことを意味します。

さらにテキストラベルを簡素化するための別の短いマッピング関数を作成しましょう。テキストラベルが簡単であればあるほど、モデルがそれらのラベルを予測するのが容易になります。

def replace_hatted_characters(batch):
    batch["sentence"] = re.sub('[â]', 'a', batch["sentence"])
    batch["sentence"] = re.sub('[î]', 'i', batch["sentence"])
    batch["sentence"] = re.sub('[ô]', 'o', batch["sentence"])
    batch["sentence"] = re.sub('[û]', 'u', batch["sentence"])
    return batch

common_voice_train = common_voice_train.map(replace_hatted_characters)
common_voice_test = common_voice_test.map(replace_hatted_characters)

CTCでは、音声チャンクを文字に分類することが一般的ですので、ここでも同じことを行います。トレーニングデータとテストデータのすべての異なる文字を抽出し、その文字の集合からボキャブラリーを構築します。

すべての転写を1つの長い転写に連結し、その文字列を文字のセットに変換するマッピング関数を作成します。マッピング関数が一度にすべての転写にアクセスできるように、map(...)関数にbatched=True引数を渡すことが重要です。

def extract_all_chars(batch):
  all_text = " ".join(batch["sentence"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)

今、我々はトレーニングデータセットとテストデータセットのすべての異なる文字の和集合を作成し、その結果のリストを列挙された辞書に変換します。

vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict

出力結果を表示:

{
 ' ': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26,
 'ç': 27,
 'ë': 28,
 'ö': 29,
 'ü': 30,
 'ğ': 31,
 'ı': 32,
 'ş': 33,
 '̇': 34
}

素晴らしい、このデータセットにはアルファベットのすべての文字が含まれていることがわかります(それは実際には驚くべきことではありません)また、特殊文字""'も抽出しました。これらの特殊文字を除外しなかった理由は次の通りです:

モデルは単語が終わったときを予測することを学習しなければならず、そうでなければモデルの予測は常に文字のシーケンスになり、単語を区切ることができなくなります。

モデルを訓練する前に前処理は非常に重要なステップであることを常に念頭に置くべきです。例えば、データを正規化するのを忘れたためにaAを区別したくないとします。 aAの違いは、文字の「音」には全く依存せず、むしろ文法的なルールに依存します。例えば、文の先頭に大文字の文字を使用します。したがって、大文字と小文字の文字の違いを取り除くことで、モデルが音声を転写することを学ぶのがより容易になります。

「 」に独自のトークンクラスを持たせるために、より目立つ文字「|」を与えます。さらに、Common Voiceのトレーニングセットで遭遇しなかった文字に対処できるように、「unknown」トークンも追加します。

vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

最後に、CTCの「ブランクトークン」に対応するパディングトークンも追加します。 「ブランクトークン」はCTCアルゴリズムの中核です。詳細については、こちらの「Alignment」セクションをご覧ください。

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

素晴らしい、これでボキャブラリーが完成し、39のトークンで構成されることがわかります。このため、事前学習済みのXLS-Rチェックポイントの上に追加する線形層の出力次元は39になります。

それでは、ボキャブラリーをjsonファイルとして保存しましょう。

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

最後のステップとして、jsonファイルを使用してWav2Vec2CTCTokenizerクラスのインスタンスにボキャブラリーをロードします。

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

このノートブックのファインチューニングモデルと一緒に作成されたトークナイザーを再利用したい場合は、tokenizerをHugging Face Hubにアップロードすることを強くお勧めします。ファイルをアップロードするリポジトリの名前を"wav2vec2-large-xlsr-turkish-demo-colab"としましょう。

repo_name = "wav2vec2-large-xls-r-300m-tr-colab"

そして、トークナイザーを🤗 Hubにアップロードします。

tokenizer.push_to_hub(repo_name)

素晴らしいですね、作成したリポジトリはhttps://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-tr-colabの下に表示されます。

Wav2Vec2FeatureExtractorの作成

音声は連続的な信号であり、コンピュータが処理するためにはまず離散化する必要があります。これは通常サンプリングと呼ばれます。ここでサンプリングレートは重要な役割を果たし、音声信号の秒間のデータポイントの数を定義します。したがって、サンプリングレートが高いほど、実際の音声信号により近い近似が得られますが、秒間の値も増える必要があります。

事前学習済みのチェックポイントは、入力データがモデルの事前学習に使用されたデータとほぼ同じ分布からサンプリングされたものであることを期待しています。サンプリングレートが異なる2つの異なる速度でサンプリングされた同じ音声信号は、非常に異なる分布を持ちます。例えば、サンプリングレートを倍にすると、データポイントが2倍になります。したがって、音声認識モデルの事前学習済みチェックポイントをファインチューニングする前に、モデルの事前学習に使用されたデータのサンプリングレートが、モデルのファインチューニングに使用されるデータのサンプリングレートと一致していることを確認することが重要です。

XLS-Rは、Babel、Multilingual LibriSpeech(MLS)、Common Voice、VoxPopuli、およびVoxLingua107のオーディオデータを16kHzのサンプリングレートで事前学習しています。元の形式のCommon Voiceは、サンプリングレートが48kHzですので、以下ではファインチューニングデータを16kHzにダウンサンプリングする必要があります。

Wav2Vec2FeatureExtractorオブジェクトのインスタンス化には、次のパラメータが必要です:

  • feature_size:音声モデルは、入力として特徴ベクトルのシーケンスを取ります。このシーケンスの長さは明らかに異なりますが、特徴のサイズは同じである必要があります。Wav2Vec2の場合、特徴のサイズは1です。なぜなら、モデルは生の音声信号を2の乗数の長さで訓練されたからです。
  • sampling_rate:モデルが訓練されたサンプリングレートです。
  • padding_value:バッチ推論では、短い入力は特定の値でパディングする必要があります。
  • do_normalize:入力をゼロ平均単位分散正規化するかどうか。通常、音声モデルは入力を正規化すると性能が向上します。
  • return_attention_mask:モデルがバッチ推論でattention_maskを使用するかどうか。一般的に、XLS-Rモデルのチェックポイントではattention_maskを必ず使用する必要があります。
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

素晴らしいですね、XLS-Rの特徴抽出パイプラインは完全に定義されました!

ユーザーフレンドリーさを向上させるために、特徴抽出器とトークナイザーは単一のWav2Vec2Processorクラスにラップされていますので、modelprocessorオブジェクトのみが必要です。

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

次に、データセットを準備します。

データの前処理

これまでは、音声信号の実際の値ではなく、転写のみを見てきました。私たちのデータセットには、sentence以外にもpathaudioという2つの列名が含まれています。pathは音声ファイルの絶対パスを示しています。見てみましょう。

common_voice_train[0]["path"]

XLS-Rは、1次元の16kHzの配列形式での入力を期待しています。つまり、オーディオファイルをロードしてリサンプリングする必要があります。

幸いにも、datasetsはこれを自動的に行います。他の列audioを呼び出すことで試してみましょう。

common_voice_train[0]["audio"]

    {'array': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
            -8.8930130e-05, -3.8027763e-05, -2.9146671e-05], dtype=float32),
     'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
     'sampling_rate': 48000}

素晴らしいですね、オーディオファイルが自動的にロードされました。これは、datasets == 1.18.3で導入された新機能である"Audio"によって可能になりました。この機能は呼び出し時にオーディオファイルを動的にロードしてリサンプリングします。

上記の例では、オーディオデータが48kHzのサンプリングレートでロードされていますが、モデルでは16kHzが期待されています。正しいサンプリングレートでオーディオ機能を設定するには、cast_columnを使用します。

common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

では、再び"audio"を見てみましょう。

common_voice_train[0]["audio"]

    {'array': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
            -7.4556941e-05, -1.4621433e-05, -5.7861507e-05], dtype=float32),
     'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
     'sampling_rate': 16000}

うまくいったようですね!データが正しくロードされ、リサンプリングされたようです。

データセットをよりよく理解し、オーディオが正しくロードされたか確認するために、いくつかのオーディオファイルを聴いてみましょう。

import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(common_voice_train)-1)

print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)

出力結果:

    sunulan bütün teklifler i̇ngilizce idi

データが正しくロードされ、リサンプリングされたようです。

話者は話す速度やアクセント、背景環境などに応じて変化しますが、全体的にはクリアな音声が収録されているようです。これは、クラウドソーシングされた読み上げ音声コーパスから期待されるものです。

最後に、データが正しく準備されているかを確認するために、音声入力の形状、転写テキスト、および対応するサンプリングレートを表示してみましょう。

rand_int = random.randint(0, len(common_voice_train)-1)

print("目標テキスト:", common_voice_train[rand_int]["sentence"])
print("入力配列の形状:", common_voice_train[rand_int]["audio"]["array"].shape)
print("サンプリングレート:", common_voice_train[rand_int]["audio"]["sampling_rate"])

出力結果:

    目標テキスト: makedonya bu yıl otuz adet tyetmiş iki tankı aldı
    入力配列の形状: (71040,)
    サンプリングレート: 16000

素晴らしい!すべてが正常に見えます。データは1次元の配列であり、サンプリングレートは常に16kHzに対応しており、目標テキストは正規化されています。

最後に、Wav2Vec2ForCTCのトレーニングに必要な形式にデータを処理するためにWav2Vec2Processorを活用しましょう。これを行うには、Datasetのmap(...)関数を使用します。

最初に、単純に batch["audio"] を呼び出すことでオーディオデータをロードしてリサンプルします。次に、ロードされたオーディオファイルから input_values を抽出します。この場合、Wav2Vec2Processor はデータを単に正規化するだけです。しかし、他の音声モデルの場合、このステップには Log-Mel 特徴量の抽出など、より複雑な特徴量の抽出が含まれる場合もあります。三番目に、トランスクリプションをラベルIDにエンコードします。

注意:このマッピング関数は、Wav2Vec2Processor クラスの使用方法の良い例です。通常のコンテキストでは、processor(...) を呼び出すと、Wav2Vec2FeatureExtractor の呼び出しメソッドにリダイレクトされます。ただし、プロセッサを as_target_processor コンテキストにラップすると、同じメソッドが Wav2Vec2CTCTokenizer の呼び出しメソッドにリダイレクトされます。詳細については、ドキュメントをご覧ください。

def prepare_dataset(batch):
    audio = batch["audio"]

    # バッチ出力は「アンバッチ」されます
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch

データの準備関数をすべての例に適用しましょう。

common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)

注意:現在、datasets はオーディオのロードとリサンプリングに torchaudiolibrosa を使用しています。独自のカスタマイズされたデータのロード/サンプリングを実装したい場合は、単に "path" 列を使用し、"audio" 列を無視してください。

長い入力シーケンスは多くのメモリを必要とします。XLS-R は self-attention に基づいています。長い入力シーケンスの場合、入力長に対してメモリ要件が二次的にスケーリングします(redditの投稿を参照)。このデモが「メモリ不足」エラーでクラッシュする場合は、以下の行のコメント解除して、トレーニングにおいて5秒より長いシーケンスをすべてフィルタリングするようにしてください。

#max_input_length_in_sec = 5.0
#common_voice_train = common_voice_train.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

素晴らしい、これでトレーニングを開始する準備が整いました!

トレーニング

データが処理されたので、トレーニングパイプラインの設定を開始する準備が整いました。トレーニングには 🤗 の Trainer を利用します。以下の手順が必要です。

  • データコレータを定義します。通常の NLP モデルとは異なり、XLS-R は入力長が出力長よりもはるかに大きいです。例えば、入力長が50000のサンプルの出力長は最大100以下です。大きな入力サイズでは、トレーニングバッチを動的にパディングすることが効率的です。つまり、すべてのトレーニングサンプルは、バッチ内の最長サンプルにのみパディングされ、全体の最長サンプルにはパディングされません。そのため、XLS-R のファインチューニングには特殊なパディングデータコレータが必要です。以下で定義します。

  • 評価指標。トレーニング中、モデルは単語エラーレートで評価されるべきです。対応する compute_metrics 関数を定義する必要があります。

  • 学習済みのチェックポイントをロードします。学習済みのチェックポイントをロードし、正しくトレーニングするために構成する必要があります。

  • トレーニングの設定を定義します。

モデルをファインチューニングした後、テストデータで正しく評価し、音声の正確な転写を学習したことを確認します。

トレーナーの設定

まず、データコレータを定義しましょう。データコレータのコードは、この例からコピーされました。

詳細には立ち入らずに、一般的なデータコレータとは異なり、このデータコレータは input_valueslabels を異なる方法で処理し、それぞれに異なるパディング関数を適用します(再び XLS-R プロセッサのコンテキストマネージャを利用しています)。これは、音声の入力と出力が異なるモダリティであるため、同じパディング関数で処理すべきではないためです。一般的なデータコレータと同様に、ラベルのパディングトークンには -100 を使用してラベルの損失計算時に考慮されないようにします。

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    入力を動的にパディングするデータコレータ。
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            データの処理に使用するプロセッサ。
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            返されるシーケンスをパディングする戦略を選択します(モデルのパディング側とパディングインデックスに従って):
            * :obj:`True` または :obj:`'longest'`:バッチ内の最長シーケンスにパディングします(単一のシーケンスの場合はパディングなし)。
            * :obj:`'max_length'`:引数 :obj:`max_length` で指定された最大長にパディングします。引数が指定されていない場合は、モデルの最大入力長にパディングします。
            * :obj:`False` または :obj:`'do_not_pad'`(デフォルト):パディングなし(つまり、異なる長さのシーケンスを含むバッチを出力できます)。
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # 入力とラベルを異なる長さに分割し、異なるパディング方法が必要です
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # ロスを正しく無視するためにパディングを -100 に置換します
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

次に、評価メトリックを定義します。先述の通り、ASRで主要なメトリックは単語エラーレート(WER)ですので、このノートブックでも使用します。

wer_metric = load_metric("wer")

モデルは、ログオッズベクトルのシーケンスを返します: y 1 , … , y m \mathbf{y}_1, \ldots, \mathbf{y}_m y 1 ​ , … , y m ​ ただし y 1 = f θ ( x 1 , … , x n ) [ 0 ] \mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0] y 1 ​ = f θ ​ ( x 1 ​ , … , x n ​ ) [ 0 ] および n > > m n >> m n > > m です。

ログオッズベクトル y 1 \mathbf{y}_1 y 1 ​ は、前述で定義した語彙の各単語の対数オッズを含んでいるため、len ( y i ) = \text{len}(\mathbf{y}_i) = len ( y i ​ ) = config.vocab_size です。モデルの最も確からしい予測に興味があるため、ログオッズの argmax(...) を取ります。また、エンコードされたラベルを元の文字列に戻すために、-100pad_token_id で置き換え、CTCスタイル 1 {}^1 1 で連続するトークンを同じトークンとしてグループ化しないようにして、ID をデコードします。

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # メトリックを計算する際にトークンをグループ化したくないので
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

それでは、Wav2Vec2-XLS-R-300M の事前学習済みチェックポイントを読み込むことができます。トークナイザの pad_token_id はモデルの pad_token_id または Wav2Vec2ForCTC の場合は CTC のブランクトークン 2 {}^2 2 を定義する必要があります。GPUメモリを節約するために、PyTorch の勾配チェックポイントを有効にし、ロスの削減を ” mean ” に設定します。

データセットが非常に小さい(トレーニングデータが約6時間分)であり、Common Voiceはかなりノイズが多いため、Facebookのwav2vec2-xls-r-300mチェックポイントの微調整にはいくつかのハイパーパラメータの調整が必要なようです。そのため、dropout、SpecAugmentのマスキングドロップアウト率、レイヤードロップアウト、および学習率のさまざまな値を試して、トレーニングが十分に安定するまで調整しました。

注意: このノートブックを使用してCommon Voiceの別の言語でXLS-Rをトレーニングする場合、これらのハイパーパラメータ設定はうまく機能しないかもしれません。使用ケースに応じて適応してください。

from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m", 
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.0,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

XLS-Rの最初のコンポーネントは、生の音声信号から音響的に意味のあるが文脈に依存しないフィーチャーを抽出するために使用されるCNNレイヤーのスタックで構成されています。このモデルのこの部分は事前トレーニング中に十分にトレーニングされており、転移学習の必要はないとされています。したがって、フィーチャー抽出部のすべてのパラメータのrequires_gradFalseに設定できます。

model.freeze_feature_extractor()

最後のステップでは、トレーニングに関連するすべてのパラメータを定義します。いくつかのパラメータの詳細について説明します:

  • group_by_lengthは、入力の長さが似ているトレーニングサンプルを1つのバッチにグループ化することで、トレーニングをより効率的に行います。これにより、モデルを通過する無駄なパディングトークンの総数を大幅に削減することで、トレーニング時間を大幅に短縮できます。
  • learning_rateおよびweight_decayは、微調整が安定するまで経験的に調整されました。これらのパラメータはCommon Voiceデータセットに強く依存しており、他の音声データセットには最適ではない可能性があります。

その他のパラメータの詳細については、ドキュメントを参照してください。

トレーニング中に、チェックポイントは400ステップごとに非同期でHubにアップロードされます。これにより、モデルがまだトレーニング中でもデモウィジェットで遊ぶことができます。

注意: モデルのチェックポイントをHubにアップロードしたくない場合は、push_to_hub=Falseに設定します。

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=30,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=400,
  eval_steps=400,
  logging_steps=400,
  learning_rate=3e-4,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=True,
)

これで、すべてのインスタンスをTrainerに渡してトレーニングを開始する準備ができました!

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

1 {}^1 1 モデルが話者の速度に依存しないようにするために、CTCでは、連続する同一のトークンは単一のトークンとしてグループ化されます。ただし、デコード時にはエンコードされたラベルをグループ化してはいけません。なぜなら、それらはモデルの予測トークンに対応していないからです。そのため、group_tokens=Falseパラメータを渡す必要があります。このパラメータを渡さない場合、"hello"のような単語が正しくエンコードされ、"helo"としてデコードされてしまいます。 2 {}^2 2 ブランクトークンは、モデルが2つのlの間にブランクトークンを挿入することで、"hello"などの単語を予測できるようにします。モデルによるCTC準拠の"hello"の予測は[PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD]となります。

トレーニング

トレーニングには、このノートブックに割り当てられたGPUに応じて複数時間かかる場合があります。トレーニングされたモデルは、トルコ語のCommon Voiceのテストデータにおいて、ある程度満足のいく結果を出しますが、最適に調整されたモデルではありません。このノートブックの目的は、ASRデータセットでXLS-R XLSR-Wav2Vec2を微調整する方法を示すことです。

Google Colabに割り当てられたGPUによっては、ここで「メモリ不足」のエラーが表示される可能性があります。その場合は、per_device_train_batch_sizeを8またはそれ以下に減らし、gradient_accumulationを増やすことが最善です。

trainer.train()

出力結果:

トレーニングの損失と検証のWERがうまく減少しています。

トレーニングの結果をHubにアップロードすることができます。次の命令を実行してください:

trainer.push_to_hub()

これで、このモデルをあなたの友人、家族、お気に入りのペットと共有することができます。彼らは、”your-username/the-name-you-picked”という識別子でそれをロードすることができます。たとえば:

from transformers import AutoModelForCTC, Wav2Vec2Processor

model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")

XLS-Rを微調整する他の例については、公式の🤗 Transformersの例をご覧ください。

評価

最終的なチェックとして、モデルをロードし、トルコ語の音声を正確に変換できたかを確認しましょう。

まず、事前学習済みのチェックポイントをロードしましょう。

model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)

次に、テストセットの最初の例を取り出し、モデルを通して実行し、ロジットのargmax(...)を取得して予測されたトークンIDを取得します。

input_dict = processor(common_voice_test[0]["input_values"], return_tensors="pt", padding=True)

logits = model(input_dict.input_values.to("cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

この関数にsampling_rate引数を渡すことを強くお勧めします。それを行わないと、デバッグが困難なエラーが発生する可能性があります。

common_voice_testをかなり修正したため、データセットインスタンスには元の文のラベルが含まれていないことに注意してください。したがって、最初の例のラベルを取得するために元のデータセットを再利用します。

common_voice_test_transcription = load_dataset("common_voice", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test")

最後に、例をデコードできます。

print("Prediction:")
print(processor.decode(pred_ids))

print("\nReference:")
print(common_voice_test_transcription[0]["sentence"].lower())

出力結果:

すばらしい!予測からは、トランスクリプションが確かに認識できますが、まだ完璧ではありません。モデルをもう少し長くトレーニングし、データの前処理にもっと時間をかけ、特にデコーディングに言語モデルを使用すると、モデルの全体的なパフォーマンスが向上するでしょう。

低リソース言語のデモモデルとしては、結果はかなり受け入れられるものです🤗。

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

人工知能

「UVeyeの共同設立者兼CEO、アミール・ヘヴェルについてのインタビューシリーズ」

アミール・ヘヴァーは、UVeyeのCEO兼共同創設者であり、高速かつ正確な異常検出により、自動車およびセキュリティ産業に直面...

人工知能

「コマンドバーの創設者兼CEO、ジェームズ・エバンスによるインタビューシリーズ」

ジェームズ・エバンズは、CommandBarの創設者兼CEOであり、製品、マーケティング、顧客チームを支援するために設計されたAIパ...

人工知能

「Kognitosの創設者兼CEO、ビニー・ギル- インタビューシリーズ」

ビニー・ギルは、複数の役職と企業を横断する多様で幅広い業務経験を持っていますビニーは現在、Kognitosの創設者兼CEOであり...

データサイエンス

2023年にAmazonのデータサイエンティストになる方法は?

ほとんどのビジネスは現在、膨大な量のデータを生成し、編集し、管理しています。しかし、ほとんどのビジネスは、収集したデ...

人工知能

「ナレ・ヴァンダニャン、Ntropyの共同創設者兼CEO- インタビューシリーズ」

Ntropyの共同創設者兼CEOであるナレ・ヴァンダニアンは、開発者が100ミリ秒未満で超人的な精度で金融取引を解析することを可...

データサイエンス

「3つの質問:ロボットの認識とマッピングの研磨」

MIT LIDSのLuca CarloneさんとJonathan Howさんは、将来のロボットが環境をどのように知覚し、相互作用するかについて議論し...