PyTorch / XLA TPUsでのHugging Face
Hugging Face on PyTorch / XLA TPUs.
お気に入りのトランスフォーマーをPyTorch / XLAを使用してCloud TPUsでトレーニングする
PyTorch-TPUプロジェクトは、Facebook PyTorchチームとGoogle TPUチームの共同作業として始まり、2019年のPyTorch Developer Conference 2019で正式に開始されました。それ以来、私たちはHugging Faceチームと協力して、PyTorch / XLAを使用してCloud TPUsでトレーニングをサポートするための一流のサポートを提供してきました。この新しい統合により、PyTorchユーザーはHugging Faceトレーナーインターフェースをそのまま維持しながら、Cloud TPUs上でモデルを実行しスケーリングすることができます。
このブログ記事では、Hugging Faceライブラリで行われた変更の概要、PyTorch / XLAライブラリの機能、Cloud TPUsでお気に入りのトランスフォーマーをトレーニングするための例、およびいくつかのパフォーマンスベンチマークについて説明します。TPUsで始めるのが待ちきれない場合は、「Cloud TPUsでトランスフォーマーをトレーニングする」セクションにスキップしてください – 私たちはTrainer
モジュール内でPyTorch / XLAのメカニクスをすべて処理します!
XLA:TPUデバイスタイプ
PyTorch / XLAは、PyTorchに新しいxla
デバイスタイプを追加します。このデバイスタイプは他のPyTorchデバイスタイプと同様に動作します。例えば、以下にXLAテンソルを作成して表示する方法が示されています。
- Huggingface TransformersとRayを使用した検索増強生成
- シンプルな人々が派手なニューラルネットワークを構築するための簡単な考慮事項
- ハギングフェイスの読書会、2021年2月 – Long-range Transformers
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
このコードはおなじみのものに見えるかもしれません。PyTorch / XLAは通常のPyTorchと同じインタフェースを使用しますが、いくつかの追加機能も提供しています。 torch_xla
をインポートすると、PyTorch / XLAが初期化され、xm.xla_device()
は現在のXLAデバイスを返します。環境によってはCPU、GPU、またはTPUのいずれかですが、このブログ記事では主にTPUに焦点を当てます。
Trainer
モジュールでは、トレーニングの詳細を定義するためにTrainingArguments
データクラスを利用します。これにはバッチサイズ、学習率、勾配の蓄積などの複数の引数や使用するデバイスが含まれます。上記のコードでは、XLA:TPUデバイスを使用する場合、TrainingArguments._setup_devices()
内で単純にTPUデバイスをTrainer
に返します。
@dataclass
class TrainingArguments:
...
@cached_property
@torch_required
def _setup_devices(self) -> Tuple["torch.device", int]:
...
elif is_torch_tpu_available():
device = xm.xla_device()
n_gpu = 0
...
return device, n_gpu
XLAデバイスによるステップの計算
XLA:TPUのトレーニングシナリオでは、複数のTPUコアで並列にトレーニングを行います(1つのCloud TPUデバイスには8つのTPUコアが含まれています)。そのため、すべての勾配がデータ並列のレプリカ間で交換され、勾配の統合と最適化ステップが行われる必要があります。このために、xm.optimizer_step(optimizer)
を提供しています。Hugging Faceトレーナーでは、トレーニングステップを更新してPyTorch / XLAのAPIを使用します。
class Trainer:
…
def train(self, *args, **kwargs):
...
if is_torch_tpu_available():
xm.optimizer_step(self.optimizer)
PyTorch / XLA入力パイプライン
PyTorch / XLAモデルを実行するには、2つの主要な部分があります:(1)モデルのグラフのトレースと実行を遅延させること(詳細な説明については、「PyTorch / XLAライブラリ」セクションを参照してください)、および(2)モデルにデータを供給することです。最適化なしでモデルのトレース/実行と入力供給を直列に実行すると、ホストCPUとTPUアクセラレータがそれぞれアイドル状態になる時間が発生します。これを回避するために、APIを提供しています。これにより、ステップnが実行中である間にステップn + 1のトレースを重ねることができます。
import torch_xla.distributed.parallel_loader as pl
...
dataloader = pl.MpDeviceLoader(dataloader, device)
チェックポイントの書き込みと読み込み
XLAデバイスからテンソルをチェックポイントに保存してから、チェックポイントから読み込むと、元のデバイスに読み込まれます。モデル内のテンソルをチェックポイント化する前に、すべてのテンソルがXLAデバイスではなくCPUデバイス上にあることを確認する必要があります。これにより、テンソルを読み込む際には、CPUデバイスを介して読み込まれ、その後、希望するXLAデバイスに配置する機会が得られます。これを実現するために、xm.save()
APIを提供しています。このAPIは、各ホストのプロセスのうちの1つからのみストレージ場所に書き込むことをすでに担当しています(またはホスト間で共有ファイルシステムを使用している場合は、グローバルに1つのみ)。
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
…
def save_pretrained(self, save_directory):
...
if getattr(self.config, "xla_device", False):
import torch_xla.core.xla_model as xm
if xm.is_master_ordinal():
# Save configuration file
model_to_save.config.save_pretrained(save_directory)
# xm.save takes care of saving only from master
xm.save(state_dict, output_model_file)
class Trainer:
…
def train(self, *args, **kwargs):
...
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(),
os.path.join(output_dir, "optimizer.pt"))
xm.save(self.lr_scheduler.state_dict(),
os.path.join(output_dir, "scheduler.pt"))
PyTorch / XLAライブラリ
PyTorch / XLAは、XLA(XLA Linear Algebra Compiler)を使用して、PyTorchディープラーニングフレームワークをCPU、GPU、およびCloud TPUなどのXLAデバイスに接続するPythonパッケージです。以下の内容の一部は、API_GUIDE.mdでも利用できます。
PyTorch / XLAテンソルは遅延評価されます
XLAテンソルとデバイスを使用するには、わずかなコードの変更が必要です。ただし、XLAテンソルはCPUテンソルやCUDAテンソルと非常に似ていますが、内部構造は異なります。CPUテンソルとCUDAテンソルは即座にまたは即時に操作を開始しますが、XLAテンソルは遅延評価されます。XLAテンソルは、結果が必要になるまで操作をグラフに記録します。このような遅延実行により、XLAは最適化を行うことができます。複数の個別の操作からなるグラフは、最適化された単一の操作に結合される場合があります。
遅延実行は通常、呼び出し元には見えません。PyTorch / XLAは自動的にグラフを構築し、それをXLAデバイスに送信し、XLAデバイスとCPU間でデータのコピー時に同期します。オプティマイザのステップを実行する際には、明示的にCPUとXLAデバイスを同期させるバリアを挿入します。
つまり、model(input)
で順伝播を呼び出し、損失を計算するloss.backward()
を実行し、最適化ステップxm.optimizer_step(optimizer)
を実行すると、すべての操作のグラフがバックグラウンドで構築されます。テンソルを明示的に評価する(例:テンソルのプリントやCPUデバイスへの移動)またはステップをマークする(これはMpDeviceLoader
が反復ごとに行う)と、完全なステップが実行されます。
トレース、コンパイル、実行、そして繰り返し
ユーザーの視点から見ると、PyTorch / XLAで実行されるモデルの典型的なトレーニング手順は、順伝播、逆伝播、およびオプティマイザのステップを実行することです。PyTorch / XLAライブラリの視点からは、少し異なる見え方をします。
ユーザーが順伝播と逆伝播を実行する間、中間表現(IR)グラフがリアルタイムでトレースされます。各ルート/出力テンソルに至るIRグラフは、次のように検査することができます:
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>> t = torch.tensor(1, device=xm.xla_device())
>>> s = t*t
>>> print(torch_xla._XLAC._get_xla_tensors_text([s]))
IR {
%0 = s64[] prim::Constant(), value=1
%1 = s64[] prim::Constant(), value=0
%2 = s64[] xla::as_strided_view_update(%1, %0), size=(), stride=(), storage_offset=0
%3 = s64[] aten::as_strided(%2), size=(), stride=(), storage_offset=0
%4 = s64[] aten::mul(%3, %3), ROOT=0
}
このライブグラフは、ユーザーのプログラムで順方向および逆方向のパスが実行される間に蓄積され、xm.mark_step()
が呼び出されると(pl.MpDeviceLoader
によって間接的に)、ライブテンソルのグラフが切り取られます。この切り取りは1つのステップの完了を示し、その後、IRグラフはXLAの高レベル操作(HLO)に変換されます。これは、XLAのIR言語です。
このHLOグラフは、TPUバイナリにコンパイルされ、その後、TPUデバイス上で実行されます。ただし、このコンパイルステップはコストがかかる場合があります。通常、単一のステップよりも長い時間がかかるため、ユーザーのプログラムを毎回コンパイルするとオーバーヘッドが高くなります。このため、HLOグラフの一意のハッシュ識別子でキー付けされたコンパイルされたTPUバイナリを格納するキャッシュを持っています。したがって、最初のステップでこのTPUバイナリキャッシュが作成された後、その後のステップでは通常、新しいTPUバイナリを再コンパイルする必要はありません。代わりに、キャッシュから必要なバイナリを簡単に参照できます。
TPUのコンパイルは、ステップの実行時間よりも通常ははるかに遅いため、グラフの形状が変化し続ける場合は、キャッシュミスが発生し、頻繁にコンパイルされます。コンパイルコストを最小限に抑えるために、テンソルの形状を可能な限り静的に保つことをお勧めします。Hugging Faceライブラリの形状は、入力トークンが適切にパディングされている場合、ほとんど静的ですので、トレーニング中はキャッシュが一貫してヒットするはずです。これは、PyTorch / XLAが提供するデバッグツールを使用して確認できます。以下の例では、コンパイルは5回しか実行されていません( CompileTime
)、一方、1220回のステップの間に実行が行われています( ExecuteTime
):
>>> import torch_xla.debug.metrics as met
>>> print(met.metrics_report())
Metric: CompileTime
TotalSamples: 5
Accumulator: 28s920ms153.731us
ValueRate: 092ms152.037us / second
Rate: 0.0165028 / second
Percentiles: 1%=428ms053.505us; 5%=428ms053.505us; 10%=428ms053.505us; 20%=03s640ms888.060us; 50%=03s650ms126.150us; 80%=11s110ms545.595us; 90%=11s110ms545.595us; 95%=11s110ms545.595us; 99%=11s110ms545.595us
Metric: DeviceLockWait
TotalSamples: 1281
Accumulator: 38s195ms476.007us
ValueRate: 151ms051.277us / second
Rate: 4.54374 / second
Percentiles: 1%=002.895us; 5%=002.989us; 10%=003.094us; 20%=003.243us; 50%=003.654us; 80%=038ms978.659us; 90%=192ms495.718us; 95%=208ms893.403us; 99%=221ms394.520us
Metric: ExecuteTime
TotalSamples: 1220
Accumulator: 04m22s555ms668.071us
ValueRate: 923ms872.877us / second
Rate: 4.33049 / second
Percentiles: 1%=045ms041.018us; 5%=213ms379.757us; 10%=215ms434.912us; 20%=217ms036.764us; 50%=219ms206.894us; 80%=222ms335.146us; 90%=227ms592.924us; 95%=231ms814.500us; 99%=239ms691.472us
Counter: CachedCompile
Value: 1215
Counter: CreateCompileHandles
Value: 5
...
Cloud TPUでTransformerをトレーニングする
VMとCloud TPUを構成するには、「Compute Engineインスタンスの設定」と「Cloud TPUリソースの起動」(執筆時のバージョンはpytorch-1.7)のセクションに従ってください。VMとCloud TPUが作成されたら、それらを使用するのはGCE VMにSSH接続し、次のコマンドを実行するだけです。トレーニングを開始します(バッチサイズはv3-8デバイス用で、v2-8ではOOMする可能性があります):
conda activate torch-xla-1.7
export TPU_IP_ADDRESS="ENTER_YOUR_TPU_IP_ADDRESS" # ex. 10.0.0.2
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
git clone -b v4.2.2 https://github.com/huggingface/transformers.git
cd transformers && pip install .
pip install datasets==1.2.1
python examples/xla_spawn.py \
--num_cores 8 \
examples/language-modeling/run_mlm.py \
--dataset_name wikitext \
--dataset_config_name wikitext-103-raw-v1 \
--max_seq_length 512 \
--pad_to_max_length \
--logging_dir ./tensorboard-metrics \
--cache_dir ./cache_dir \
--do_train \
--do_eval \
--overwrite_output_dir \
--output_dir language-modeling \
--overwrite_cache \
--tpu_metrics_debug \
--model_name_or_path bert-large-uncased \
--num_train_epochs 3 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--save_steps 500000
上記のコードは、約200分以内でトレーニングを完了し、評価のパープレキシティは約3.25になります。
パフォーマンスのベンチマーク
以下の表は、PyTorch / XLAで実行されるv3-8 Cloud TPUシステム(4つのTPU v3チップを含む)で、bert-large-uncasedのトレーニングのパフォーマンスを示しています。すべてのベンチマーク測定に使用されるデータセットは、WikiText103データセットです。また、Hugging Faceのexamplesで提供されるrun_mlm.pyスクリプトを使用します。ワークロードがホストCPUに制約されないようにするために、これらのテストにはn1-standard-96 CPU構成を使用していますが、性能に影響を与えることなく、より小さい構成を使用することもできるかもしれません。
TPUsでのPyTorch / XLAの始め方
始めるには、Hugging Faceのexamplesの「Running on TPUs」セクションを参照してください。APIの詳細な説明については、API_GUIDEを確認してください。また、パフォーマンスのベストプラクティスについては、TROUBLESHOOTINGガイドを参照してください。一般的なPyTorch / XLAの例については、無料のCloud TPUアクセスを提供する以下のColabノートブックを実行してください。GCPで直接実行するには、ドキュメントサイトの「PyTorch」というラベルが付いたチュートリアルをご覧ください。
その他の質問や問題はありますか?以下のリンクからissueや質問をオープンしてください。https://github.com/huggingface/transformers/issues または https://github.com/pytorch/xla/issues
We will continue to update VoAGI; if you have any questions or suggestions, please contact us!
Was this article helpful?
93 out of 132 found this helpful
Related articles