「JAXとHaikuを使用してゼロからTransformerエンコーダを実装する🤖」

「JAXとHaikuを使ってゼロからTransformerエンコーダを実装する方法🤖」

トランスフォーマーの基本的な構成要素の理解

Transformers, in the style of Edward Hopper (generated by Dall.E 3)

2017年に「Attention is all you need」[0]という画期的な論文で紹介されたトランスフォーマーアーキテクチャは、最近の深層学習の歴史で最も影響力のあるブレイクスルーの一つであり、大規模言語モデルの台頭やコンピュータビジョンの分野でも利用されています。

従来の状態-of-the-artアーキテクチャである「再帰」に頼るアーキテクチャ(Long Short-Term Memory(LSTM)ネットワークやGated Recurrent Units(GRU)など)を成功裏に引き継ぎ、トランスフォーマーは「セルフアテンション」という概念を導入し、エンコーダ/デコーダアーキテクチャと組み合わせました。

本記事では、トランスフォーマーの半分である「エンコーダ」をゼロからステップバイステップで実装します。主なフレームワークにはDeepMindのディープラーニングライブラリである「Haiku」と、「JAX」を使用します。

JAXに慣れていない場合やその素晴らしい機能についてのリマインダーが必要な場合は、以前の記事で強化学習のコンテキストで取り上げています:

JAXを使用してRL環境をベクトル化および並列化し、光速でQ学習を行う

GridWorld環境をベクトル化し、CPU上で30個のQ学習エージェントを並列にトレーニングする方法について学ぶ

towardsdatascience.com

この記事では、エンコーダを構成する各ブロックを詳細に説明し、効率的な実装方法を学びます。特に、この記事の概要は次のとおりです:

  • 埋め込みレイヤーと位置エンコーディング
  • マルチヘッドアテンション
  • 残余接続とレイヤー正規化
  • 位置ごとのフォワードネットワーク

免責事項:この記事は完全な概念の紹介を意図したものではなく、まずは実装に焦点を当てています。必要な場合は、この投稿の最後にあるリソースに参照してください。

いつものように、この記事に対する完全にコメントされたコードやイラスト付きのノートブックは、「GitHub」で利用できます。

GitHub — RPegoud/jab: JAXで実装された基礎的なディープラーニングモデルのコレクション

JAXで実装された基礎的なディープラーニングモデルのコレクション — GitHub — RPegoud/jab: JAXで実装された基礎的なディープラーニングモデルのコレクション

github.com

主要なパラメータ

始める前に、エンコーダブロックで重要な役割を果たすいくつかのパラメータを定義する必要があります:

  • シーケンスの長さ(seq_len):シーケンス内のトークンまたは単語の数。
  • 埋め込み次元(embed_dim):埋め込みの次元、つまり単一のトークンまたは単語を記述するために使用される数値の数。
  • バッチサイズ(batch_size):同時に処理されるシーケンスの数、つまり入力のバッチのサイズ。

エンコーダーモデルへの入力シーケンスは、通常、形状が(batch_size, seq_len)になります。この記事では、batch_size=32seq_len=10を使用して、エンコーダーが同時に10語のシーケンスを32個処理することを意味します。

処理の各ステップでデータの形状に注意を払うことで、データがエンコーダーブロック内でどのようにフローするかをより視覚化し理解することができます。以下はエンコーダーの概要です。まずはじめに埋め込み層位置エンコーディングから始めます:

Transformerエンコーダーブロックの表現(著者作)

埋め込み層と位置エンコーディング

先に述べたように、モデルはトークンのバッチ化されたシーケンスを入力とします。これらのトークンの生成は、データセット内の一意の単語のセットを収集し、それぞれにインデックスを割り当てるだけの簡単なものかもしれません。その後、32個の10単語のシーケンスをサンプリングし、語彙中の各単語をそのインデックスに置き換えます。この手順により、期待される形状の(batch_size, seq_len)の配列が得られます。

それではエンコーダーの準備が整いました。最初のステップは、シーケンスのために「位置埋め込み」を作成することです。位置埋め込みは、単語埋め込み位置エンコーディング合計です。

単語埋め込み

単語埋め込みは、語彙中の単語間の意味意味的な関係をエンコードすることができます。この記事では、埋め込みの次元は64と固定されています。つまり、各単語は64次元のベクトルで表されており、類似の意味を持つ単語は似たような座標を持ちます。さらに、これらのベクトルを操作することで、下記のように単語間の関係を抽出することができます。

単語埋め込みによる類推の例(画像:developers.google.comから)

Haikuを使用して、学習可能な埋め込みを生成することは、次のように簡単です:

hk.Embed(vocab_size, embed_dim)

これらの埋め込みは、モデルのトレーニング中に他の学習可能なパラメータとともに更新されます(後ほど詳しく説明します)。

位置エンコーディング

再帰ニューラルネットワークとは異なり、Transformerは共有された隠れ状態に基づいてトークンの位置を推測することができません。そのため、位置エンコーディング、トークンの位置を伝えるベクトルを導入します。

基本的に、各トークンには交互にsinとcosの値が入った位置ベクトルが割り当てられます。これらのベクトルは単語埋め込みの次元数に一致させることで、両方を合計することができます。

具体的には、元のTransformer論文では以下の関数が使用されています:

位置エンコーディング関数(「Attention is all you need」、Vaswani et al. 2017より転載)

以下の図は、位置エンコーディングの機能をさらに理解するために役立ちます。最上段のプロットの最初の行を見てみましょう。そこには0と1の交互の列が見られます。実際、行はシーケンス内のトークンの位置(pos変数)を表し、列は埋め込み次元(i変数)を表します。

したがって、pos=0の場合、前の方程式は、偶数の埋め込み次元に対してsin(0)=0を返し、奇数の次元に対してcos(0)=1を返します。

また、隣接する行は似たような値を共有していることがわかりますが、最初と最後の行は大きく異なります。この特性は、モデルがシーケンス内の単語の距離順序を評価するのに役立ちます。

最後に、三番目のプロットは位置エンコーディングと埋め込みの合計を表し、これが埋め込みブロックの出力です。

単語の埋め込みと位置エンコーディングの表現、seq_len=16、embed_dim=64(著者作成)

Haikuを使用して、埋め込みレイヤーを次のように定義します。他のディープラーニングフレームワークと同様に、Haikuではカスタムモジュール(ここではhk.Module)を定義して学習可能なパラメータを保持し、モデルのコンポーネントの振る舞いを定義することができます。

各Haikuモジュールは__init__関数と__call__関数を持つ必要があります。ここでは、__call__関数は単純にhk.Embed関数と位置エンコーディングを使用して埋め込みを計算し、それらを合算します。

位置エンコーディング関数は、パフォーマンスのためにvmaplax.condなどのJAXの機能を使用しています。これらの関数について詳しく知りたくない場合は、私の以前の投稿をチェックしてみてください。

簡単に言えば、vmap単一のサンプル用の関数を定義し、ベクトル化してデータのバッチに適用できるようにする機能です。 in_axesパラメータは、dim入力の最初の軸を繰り返すことを指定するために使用されます(それは埋め込み次元です)。一方、lax.condはPythonのif/else文のXLA互換バージョンです。

セルフアテンションとマルチヘッドアテンション

アテンションは、入力単語に関連して、シーケンス内の各単語の重要性を計算することを目指しています。たとえば、次の文で:

「黒い猫がソファにジャンプして横になって眠りについたのは、疲れていたからです」

単語「それ」はモデルにとってかなり曖昧かもしれませんが、「」と「ソファ」の両方を指す可能性があります。訓練が十分にされたアテンションモデルは、「それ」が「」を指し、従って文の残りの部分に対して適切な注意値を割り当てることができるでしょう。

基本的に、アテンション値は、入力単語の文脈に基づくある単語の重要性を示す重みと見なすことができます。たとえば、「ジャンプ」という単語のアテンションベクトルは、「」(何がジャンプしたか?)、 「~に」、および「ソファ」(どこにジャンプしたか?)のような単語に対して高い値を持つでしょう。なぜなら、これらの単語はその文脈において関連性があるからです。

アテンションベクトルの視覚化(著者作成)

Transformerの論文では、「Scaled Dot-Product Attention」を使用して注意力が計算されます。これは以下の式で示されます:

Scaled Dot-Product Attention (reproduced from “Attention is all you need”, Vaswani et al. 2017)

ここで、Q、K、Vは「クエリ(Queries)」、「キー(Keys)」、「値(Values)」を表しています。これらの行列は、学習した重みベクトルWQ、WK、WVと位置埋め込みとの積を取ることによって得られます。

これらの名前は、情報がどのように処理され、注意ブロックで重み付けされるかを理解するのに役立つ「抽象的な概念(abstractions)」です。これは「検索システム」の語彙による比喩です(たとえば、YouTubeでビデオを検索するなど)。

以下は「直感的な(intuitive)」な説明です:

  • クエリ(Queries):シーケンスのすべての位置についての「質問セット」として解釈することができます。たとえば、単語の文脈を尋ね、シーケンスの最も関連のある部分を特定しようとします。
  • キー(Keys):クエリが対話する情報を保持していると考えることができ、クエリとキーの互換性がクエリが対応する値にどれだけ注意を払うかを決定します。
  • 値(Values):キーとクエリの一致によって、関連するキーがどれだけ重要かを判断し、値はキーとペアになった実際のコンテンツです。

以下の図では、クエリがYouTubeの検索であり、キーがビデオの説明とメタデータであり、値は関連するビデオです。

Intuitive representation of the Queries, Keys, Values concept (made by the author)

この場合、クエリ、キー、値は「同じソース」から来ます(入力シーケンスから派生しているため)、したがって「自己注意(self-attention)」という名前が付けられています。

注意スコアの計算は通常「複数回並列に行われます」。各回の計算では、埋め込みの一部(fraction of the embeddings)が使用されます。この仕組みを「Multi-Head Attention」と呼び、各ヘッドがデータのいくつかの異なる表現を並列に学習できるようにすることで、より「堅牢な(robust)」モデルを実現します。

単一の注意ヘッドは、一般的に形状が(batch_size、seq_len、d_k)の配列を処理します。ここで、d_kはヘッドの数と埋め込みの次元の比率(d_k = n_heads/embed_dim)として設定されます。このように、各ヘッドの出力を適切に連結することで、形状が(batch_size、seq_len、embed_dim)の配列が得られます。

注意行列の計算は、いくつかのステップに分けることができます:

  • まず、学習可能な重みベクトルWQ、WK、WVを定義します。これらのベクトルの形状は(n_heads、embed_dim、d_k)です。
  • 同時に、位置埋め込みを重みベクトルと「乗算」します。これにより、形状が(batch_size、seq_len、d_k)のQ、K、V行列が得られます。
  • 次に、QとK(転置)の「ドット積」をスケーリングします。このスケーリングでは、ドット積の結果をd_kの平方根で割り、行列の行に対してsoftmax関数を適用します。したがって、入力トークン(つまり、行)の注意スコアは1に合計されるため、値が大きくなりすぎて計算が遅くなるのを防ぎます。出力の形状は(batch_size、seq_len、seq_len)です。
  • 最後に、前の操作の結果にVをドット積します。これにより、出力の形状が(batch_size、seq_len、d_k)になります。
作者によるアテンションブロック内の行列演算の視覚表現
  • 各アテンションヘッドの出力は、(batch_size, seq_len, embed_dim)の形状の行列を形成するために連結されることがあります。Transformer論文では、マルチヘッドアテンションモジュールの最後に線形層も追加され、すべてのアテンションヘッドから学習された表現を集約および結合します。
マルチヘッドアテンション行列と線形層の連結(作者作成)

Haikuでは、マルチヘッドアテンションモジュールは以下のように実装されます。 __call__関数は、上記のグラフと同じロジックに従い、クラスのメソッドはvmap(異なるアテンションヘッドと行列上での操作をベクトル化するため)およびtree_map(行列のドット積を重みベクトルにマップするため)など、JAXのユーティリティを活用しています。

残差接続とレイヤー正規化

Transformerグラフでお気づきかもしれませんが、マルチヘッドアテンションブロックとフィードフォワードネットの後には残差接続レイヤー正規化があります。

残余接続またはスキップ接続

残余接続は、勾配がモデルのパラメータを効果的に更新するのに十分に小さくなるときに発生する勾配消失問題を解決するための標準的な解決策です。

このような問題は、特に深いアーキテクチャで自然に生じるため、残余接続は、コンピュータビジョンのResNet(Kaiming et al、2015)、強化学習のAlphaZero(Silver et al、2017)などの複雑なモデル、そしてもちろんTransformersにも使用されます。

実際には、残余接続は、特定の層の出力を次の層に直接転送し、その過程で1つ以上の層をスキップします。たとえば、マルチヘッドアテンションの周りの残余接続は、マルチヘッドアテンションの出力と位置埋め込みの総和と同等です。

これにより、逆伝播中に勾配がアーキテクチャをより効率的に流れるようになり、より迅速な収束とより安定したトレーニングにつながることが通常です。

Transformersの残余接続の表現(作者作成)

レイヤー正規化

レイヤー正規化は、アテンションブロックのような場所で発生する可能性のある複数の行列が各フォワードパスで乗算されるため、「爆発する」(無限に近づく)値がモデルを通じて伝播されないようにするのに役立ちます。

バッチ正規化はバッチの次元を横断して正規化し、一様分布を仮定しますが、レイヤー正規化では特徴線上で操作が行われます。このアプローチは、各文が異なる分布を持つ場合があり、意味や語彙の異なるバッチに適しています。

埋め込みアテンション値などの特徴を横断して正規化することで、レイヤー正規化はデータを一様なスケールに標準化し、それぞれの異なる文の特徴を混同せずに維持します。

トランスフォーマーの文脈でのレイヤーノーマライゼーションの表現(著者作成)

レイヤーノーマライゼーションの実装は非常に簡単で、学習可能なパラメータαとβを初期化し、望ましい特徴軸に沿って正規化を行います。

位置ごとのフィードフォワードネットワーク

エンコーダーの最後のコンポーネントである位置ごとのフィードフォワードネットワークを説明します。この完全に接続されたネットワークは、アテンションブロックの正規化された出力を入力として受け取り、非線形性を導入し、モデルの容量を高め、複雑な関数を学習するために使用されます。

これはgeluアクティベーションによって区切られた2つの密な層で構成されています:

このブロックの後、エンコーダーを完成させるために、もう一つの残余接続とレイヤーノーマライゼーションがあります。

まとめ

以上です!今頃には、トランスフォーマーのエンコーダーの主要な概念について理解しているはずです。以下は、Haikuでは各層に名前を割り当てることによって、学習可能なパラメータを分離して簡単にアクセスできるようにするためのエンコーダーの完全なクラスです。 __call__関数は、エンコーダーの異なるステップの良いまとめを提供しています:

実際のデータでこのモジュールを使用するには、エンコーダークラスをカプセル化した関数にhk.transformを適用する必要があります。実際に、JAXは関数型プログラミングのパラダイムを採用しているため、Haikuも同じ原則に従います。

エンコーダークラスのインスタンスを含む関数を定義し、フォワードパスの出力を返します。 hk.transformを適用すると、initapplyの2つの関数にアクセスできる変換されたオブジェクトが返されます。

前者は、ランダムなキーといくつかのダミーデータ(ここでは、形状がbatch_size、seq_lenのゼロの配列を渡していることに注意してください)でモジュールを初期化することを可能にし、後者は実際のデータを処理することを可能にします。

# 注:以下の2つの構文は等価です。  # 1:transformをクラスデコレータとして使用@hk.transformdef encoder(x):  ...  return model(x)  encoder.init(...)encoder.apply(...)# 2:別々に変換を適用def encoder(x):  ...  return model(x)encoder_fn = hk.transform(encoder)encoder_fn.init(...)encoder_fn.apply(...)

次の記事では、トランスフォーマーを完成させるために、これまでに紹介したほとんどのブロックを再利用するデコーダーを追加する方法と、Optaxを使用して特定のタスクでモデルを訓練する方法を学びます!

結論

ここまで読んでいただき、ありがとうございます。もしご興味があれば、コード付きで完全にコメントされたGitHub上でそれを見つけることができます。さらに、おもちゃのデータセットを使用した詳細な手順とウォークスルーもあります。

GitHub – RPegoud/jab: JAXで実装された基礎的なディープラーニングモデルのコレクション

JAXで実装された基礎的なディープラーニングモデルのコレクション – GitHub – RPegoud/jab

github.com

もしトランスフォーマーについて深く探求したい場合は、以下のセクションには私がこの記事を作成するのに役立った一部の記事があります。

次回まで👋

参考文献とリソース:

[1] Attention is all you need(2017年)、Vaswani et al、Google

[2] Attentionメカニズムのキー、クエリ、値は実際には何ですか?(2019年)Stack Exchange

[3] イラスト化されたトランスフォーマー(2018年)、ジェイ・アラマー

[4] トランスフォーマーモデルの位置符号化への優しい紹介(2023年)、メフリン・サエード、Machine Learning Mastery

画像のクレジット

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more