JAXを使用してRL環境をベクトル化・並列化する:光の速さでのQ学習⚡

JAXを使ってRL環境を効率化・並列化する:光速のQ学習⚡️

この記事では、RL(強化学習)環境のベクトル化とCPU上で30個のQ学習エージェントを平行してトレーニングする方法を学びます。秒間180万回のイテレーション速度で行います。

Google DeepMindの画像、Unsplashから

以前のストーリーでは、Temporal-Difference Learning(TD学習)、特にQ学習について、GridWorldの文脈で紹介しました。

Temporal-Difference Learningと探索の重要性:イラストガイド

動的グリッドワールドでのモデルフリー(Q学習)とモデルベース(Dyna-QとDyna-Q+)TD法の比較

towardsdatascience.com

この実装は、これらのアルゴリズムのパフォーマンスと探索メカニズムの違いをデモンストレーションする目的で使用されましたが、非常に遅いです。

実際、環境とエージェントは主にNumpyでコーディングされていますが、これはRLにおいて標準ではなく、コードを理解しやすくデバッグしやすくするものではありません。

この記事では、環境のベクトル化とエージェントのトレーニングを効率的に行うためのJAXを使用した平行化について見ていきます。具体的には、以下の内容をカバーします:

  • JAXの基礎知識とRLに有用な機能
  • ベクトル化された環境とその高速化の理由
  • JAXでの環境、ポリシー、Q学習エージェントの実装
  • 単一エージェントのトレーニング
  • エージェントトレーニングの並列化方法とその簡易性

本記事で紹介されるすべてのコードはGitHubで入手できます:

GitHub – RPegoud/jax_rl: RLアルゴリズムとベクトル化された環境のJAX実装

JAX実装のRLアルゴリズムとベクトル化された環境 – GitHub – RPegoud/jax_rl: RL…

github.com

JAXの基礎知識

JAXは、Googleによって開発された別のPythonディープラーニングフレームワークで、DeepMindなどの企業でも広く使用されています。

JAXは、Autograd(自動微分)とXLA(加速された線形代数、TensorFlowコンパイラ)が組み合わされた、高パフォーマンスの数値計算のためのものです。- 公式ドキュメント

通常のPython開発者が慣れているものとは異なり、JAXはオブジェクト指向プログラミング(OOP)のパラダイムではなく、むしろ関数型プログラミング(FP)を採用しています。

単純に言えば、JAXは純粋関数決定論的であり、副作用のない)およびイミュータブルデータ構造(データをその場で変更する代わりに、望ましい変更を加えた新しいデータ構造作成する)を主な構築ブロックとしています。その結果、FPはより機能的で数学的なアプローチをプログラミングに促進し、数値計算や機械学習のようなタスクに適しています。

これら2つのパラダイムの違いを、Q-update関数の疑似コードを見ながら説明しましょう。

  • オブジェクト指向のアプローチは、クラスのインスタンスに各種の状態変数(Q値など)を含めます。アップデート関数は、インスタンスの内部状態を更新するクラスメソッドとして定義されます。
  • 関数型プログラミングのアプローチは、純粋関数に依存します。実際に、このQ-update関数は決定論的であり、Q値が引数として渡されるため、同じ入力でこの関数を呼び出すと同じ出力が返ります。一方、クラスメソッドの出力はインスタンスの内部状態に依存する場合があります。また、配列などのデータ構造グローバルスコープで定義および変更されます。
Implementing a Q-update in Object-Oriented Programming and Functional Programming (made by the author)

JAXはRLの文脈で特に有用な、さまざまな関数デコレータを提供しています。

  • vmap (ベクトル化マップ):単一のサンプルに作用する関数をバッチに適用できるようにします。たとえば、env.step()が単一の環境でステップを実行する関数である場合、vmap(env.step)()は複数の環境でステップを実行する関数です。つまり、vmapは関数にバッチ次元を追加します。
Illustration of a step function vectorized using vmap (made by the author)
  • jit (ジャストインタイムコンパイル):JAXが「JAX Python関数のジャストインタイムコンパイル」を実行できるようにします。jitを使用することで、関数をコンパイルし、大幅な速度向上を提供できます(ただし、関数の最初のコンパイル時には追加のオーバーヘッドが発生します)。
  • pmap (パラレルマップ):vmapと同様に、pmapも簡単な並列化を可能にします。ただし、関数にバッチ次元を追加するのではなく、関数を複製して複数のXLAデバイス上で実行します。注意点:pmapを適用する場合、jitも自動的に適用されます。
Illustration of a step function parallelized using pmap (made by the author)

JAXの基礎を理解したところで、環境をベクトル化することでどのように大幅な速度向上を実現するかを見ていきましょう。

ベクトル化された環境:

まずは、ベクトル化環境とは何か、およびベクトル化が解決する問題は何かについて見ていきましょう。

ほとんどの場合、RLの実験はCPUとGPU間のデータ転送によって遅くなることがあります。Proximal Policy Optimization(PPO)などのDeep Learning RLアルゴリズムでは、ポリシーを近似するためにニューラルネットワークを使用します。

Deep Learningではいつものように、ニューラルネットワークはトレーニングおよび推論時にGPUを使用します。しかし、ほとんどの場合、環境CPU上で実行されます(複数の環境が並行して使用されている場合でもです)。

これは、通常の強化学習のループ(ポリシー(ニューラルネットワーク)によるアクションの選択と環境からの観測および報酬の受け取り)において、GPUとCPUの間で頻繁なデータのやり取りが必要であり、パフォーマンスに影響を及ぼします。

さらに、PyTorchなどのフレームワークを「jitting」せずに使用すると、GPUがPythonからCPUからの観測および報酬を受け取るのを待たなければならず、いくらかのオーバーヘッドが発生するかもしれません。

Usual RL batched training setup in PyTorch (made by the author)

一方、JAXを使用すると、バッチ環境を容易にGPU上で実行し、GPU-CPUのデータ転送による摩擦を取り除くことができます。

さらに、JAXはコードをXLAにコンパイルするため、Pythonの効率の悪さの影響を受けにくくなります。

RL batched training setup in JAX (made by the author)

詳細や興味深い応用については、Chris Luのこのブログ記事を強くお勧めします。

環境、エージェント、およびポリシーの実装:

強化学習の実験のさまざまなパーツの実装を見てみましょう。次は、必要な基本的な機能の概要です:

Class methods required for a simple RL setup (made by the author)

環境

この実装は、Nikolaj Goodger氏がJAXでの環境の書き方について紹介した優れた記事に基づいています。

JAXでRL環境を作成する

1.25 Billion Step/SecでCartPoleを実行する方法

VoAGI.com

まず、環境とそのメソッドの高レベルな概要を見てみましょう。これは、JAXで環境を実装するための一般的な計画です:

クラスメソッドについて詳しく見てみましょう(リマインダーとして、関数名が「_」で始まる場合、それらはプライベートであり、クラスのスコープ外で呼び出すべきではありません):

  • _get_obs:このメソッドは環境の状態をエージェントの観測に変換します。部分的に観測可能な環境や確率的な環境の場合、状態に適用される処理関数はここに記述されます。
  • _reset:複数のエージェントを並行して実行するために、エピソードの完了時に各エージェントのリセットを行うメソッドが必要です。
  • _reset_if_done:このメソッドは各ステップで呼び出され、”done”フラグがTrueに設定されている場合に_resetをトリガーします。
  • reset:このメソッドは実験の開始時に各エージェントの初期状態と関連するランダムキーを取得するために呼び出されます
  • step:状態とアクションを受け取り、環境は観測(新しい状態)、報酬、および更新された”done”フラグを返します。

実際には、GridWorld環境の一般的な実装は次のようになります:

前述のように、すべてのクラスメソッドは関数型プログラミングパラダイムに従っています。実際には、クラスインスタンスの内部状態を更新することはありません。さらに、クラス属性はすべてインスタンス化後に変更されない定数です。

詳細を見てみましょう:

  • __init__:GridWorldの文脈では、利用可能なアクションは[0, 1, 2, 3]です。これらのアクションは、self.movementsを使用して2次元配列に変換され、step関数で状態に追加されます。
  • _get_obs:環境は決定論的かつ完全に観測可能なので、エージェントは加工された観測ではなく、直接状態を受け取ります。
  • _reset_if_done:引数のenv_stateは、(state, key)のタプルであり、keyはjax.random.PRNGKeyです。この関数は、doneフラグがTrueに設定されている場合に初期状態を返すだけですが、JAX jitted関数内では従来のPythonの制御フローを使用することはできません。jax.lax.condを使用することで、次の式と等価な式を得ることができます:
def cond(condition, true_fun, false_fun, operand):  if condition: # doneフラグ == True    return true_fun(operand)  # self._reset(key)を返す  else:    return false_fun(operand) # env_stateを返す
  • step:アクションを移動に変換し、現在の状態に追加します(jax.numpy.clipにより、エージェントがグリッド内にとどまることが保証されます)。その後、環境がリセットする必要があるかどうかをチェックする前に、env_stateのタプルを更新します。step関数はトレーニング中に頻繁に使用されるため、jit化することでかなりのパフォーマンス向上が期待できます。@partial(jit, static_argnums=(0, )デコレータは、クラスメソッドの「self」引数が静的と見なされるべきであることを示しています。言い換えれば、クラスのプロパティは定数であり、step関数への連続した呼び出し中に変更されることはありません。

Q学習エージェント

Q学習エージェントは、更新関数、静的な学習率、および割引率で定義されます。

再度、update関数をjit化する際には、”self”引数を静的に渡します。また、q_values行列はset()を使用して直接修正され、クラス属性としては保存されません。

ε-グリーディポリシー

最後に、この実験で使用されるポリシーは標準的なε-グリーディポリシーです。重要なポイントは、ポリシーがランダムなタイブレークを使用することです。つまり、最大のQ値が一意でない場合、アクションは最大のQ値から一様にサンプリングされます(argmaxを使用すると、常に最初の最大のQ値を持つアクションが返されます)。これは、Q値がゼロの行列として初期化された場合、アクション0(右に移動)が常に選択されるという点で特に重要です。

それ以外の場合、ポリシーは次のように要約できます:

action = lax.cond(            explore, # p < epsilonの場合            _random_action_fn, # キーからランダムなアクションを選択            _greedy_action_fn, # Q値に基づいてグリーディなアクションを選択            operand=subkey, # 上記の関数の引数としてsubkeyを使用        )return action, subkey

keyをJAXで使用する場合(ここではランダムな浮動小数点数をサンプリングし、random.choiceを使用した)、キーを分割するのは一般的な方法です(つまり、「新しいランダムな状態に移る」、詳細はこちらを参照)。

単一エージェントのトレーニングループ:

必要なすべてのコンポーネントが揃ったので、単一エージェントをトレーニングしましょう。

以下は、Pythonicなトレーニングループです。ポリシーを使用してアクションを選択し、環境でステップを実行し、Q値を更新します。エピソードの終わりまで繰り返します。次に、N個のエピソードに対して同様のプロセスを繰り返します。数分後に見るように、エージェントをトレーニングするこの方法はかなり効率的ではありませんが、アルゴリズムの主要なステップを読みやすい方法でまとめています。

1つのCPU上で、1秒間に881エピソードと21,680ステップのペースで、10,000エピソードを11秒で終了します。

100%|██████████| 10000/10000 [00:11<00:00, 881.86it/s]総ステップ数: 238 488秒間のステップ数: 21 680

それでは、同じトレーニングループをJAX構文を使用して再現しましょう。以下は、ロールアウト関数の高レベルな説明です:

トレーニングロールアウト関数のJAX構文使用(著者制作)

要約すると、ロールアウト関数は以下のような動作をします:

  1. 観測報酬、および完了フラグを、時間ステップの数と同じ次元の空の配列として初期化します(jax.numpy.zerosを使用)。Q値は、形状が[時間ステップ+1, グリッドの次元x, グリッドの次元y, n_actions]の空の行列として初期化されます。
  2. env.reset()関数を呼び出して初期状態を取得します。
  3. jax.lax.fori_loop()関数を使用して、fori_body()関数をN回呼び出します。ここで、Nは時間ステップのパラメータです。
  4. fori_body()関数は、前回のPythonループと同様の振る舞いをします。アクションの選択、ステップの実行、およびQのアップデートの計算の後、obs、報酬、完了、およびq_values配列を直接更新します(Qのアップデートは時間ステップt+1を対象としています)。

この追加の複雑さにより、85倍の高速化が実現され、エージェントはおおよそ1,830,000ステップ/秒でトレーニングされます。環境が単純なため、トレーニングは単一のCPU上で行われます。

ただし、エンドツーエンドのベクトル化はよりスケーラブルであり、複雑な環境複数のGPUを利用するアルゴリズムに適用されるとさらに優れた効果があります(Chris Lu’s articleでは、CleanRL PyTorchを使用したPPOの実装とJAXの再現版との間で4,000倍の高速化が報告されています)。

100%|██████████| 1000000/1000000 [00:00<00:00, 1837563.94it/s]総ステップ数: 1,000,000秒間のステップ数: 1,837,563

エージェントのトレーニング後、GridWorldの各セル(つまり、状態)の最大Q値をプロットし、それが初期状態(右下隅)から目的地(左上隅)に効果的に移動することが確認されました。

GridWorldの各セルの最大Q値のヒートマップ表現(著者制作)

並列エージェントトレーニングループ:

約束通り、単一のエージェントをトレーニングするために必要な関数を書いたので、バッチ処理された環境で複数のエージェント並列にトレーニングするためにはほとんど作業が残っていません!

vmapのおかげで、前の関数をデータのバッチで動作するようにすばやく変換できます。ただし、期待される入力および出力の形状を指定する必要があります。たとえば、env.step:

  • in_axes = ((0,0), 0)は、入力の形状を表します。つまり、env_stateのタプル(次元(0、0))と観測(次元0)から構成されます。
  • out_axes = ((0, 0), 0, 0, 0)は、出力の形状を表します。出力は((env_state), obs, reward, done)です。
  • これで、env_statesとactionsの配列に対してv_stepを呼び出し、処理されたenv_states、観測、報酬、および完了フラグの配列を受け取ることができます。
  • パフォーマンスのために、バッチ処理された関数もすべて化します(環境をreset()するのはトレーニング関数で一度だけ呼び出されるため、jitを使う必要はないかもしれません)。

最後の調整は、各エージェントのデータを考慮するために、私たちの配列にバッチの次元を追加することです。

これにより、シングルエージェントの関数と比較して、最小の調整で複数のエージェントを並列に訓練するための関数を得ることができます。

このバージョンのトレーニング関数では、同様のパフォーマンスが得られます。

100%|██████████| 100000/100000 [00:02<00:00, 49036.11it/s]総ステップ数:100,000 * 30 = 3,000,000秒当たりのステップ数:49,036 * 30 = 1,471,080

以上です!ここまでお読みいただき、ベクトル化環境の実装に関する役立つ紹介を提供できたことを願っています。

お楽しみいただけた場合は、この記事を共有し、私のGitHubリポジトリにスターをつけていただけると幸いです。ご支援いただきありがとうございます!🙏

GitHub – RPegoud/jax_rl: JAX強化学習アルゴリズムとベクトル化環境の実装

JAX強化学習アルゴリズムとベクトル化環境の実装 – GitHub – RPegoud/jax_rl: JAX強化学習アルゴリズムとベクトル化環境の実装

github.com

最後に、もう少し掘り下げたい方のために、JAXを始めるのに役立つ有用なリソースのリストをご紹介します。

厳選された素晴らしいJAXの記事とリソースのリスト:

[1] Coderized(関数型プログラミング) 純粋なコーディングスタイルであり、バグはほぼ不可能です、YouTube

[2] Aleksa Gordić、JAX From Zero to Hero YouTube Playlist(2022年)、The AI Epiphany

[3] Nikolaj Goodger、Writing an RL Environment in JAX(2021年)

[4] Chris Lu、Achieving 4000x Speedups and Meta-Evolving Discoveries with PureJaxRL(2023年)、University of OxfordFoerster Lab for AI Research

[5] Nicholas Vadivelu、Awesome-JAX(2020年)、JAXのライブラリ、プロジェクト、リソースのリスト

[6] JAX公式ドキュメント、Training a Simple Neural Network, with PyTorch Data Loading

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more