Deep learning論文の数学をPyTorchで効率的に実装する:SimCLR コントラスティブロス

SimCLRコントラスティブロス:Deep Learning論文の数学をPyTorchで効率的に実装する

高度な数学の公式をパフォーマンスの良いPyTorchコードに実装する方法を学ぶ。

Jeswin Thomasによる写真、Unsplash

イントロダクション

ディープラーニングモデルや損失関数の数学の理解を深め、PyTorchのスキルを向上させるための最良の方法の一つは、自分自身でディープラーニングの論文を実装することに慣れることです。

本やブログの投稿は、コーディングを始め、ML / DLの基礎を学ぶのに役立つかもしれませんが、いくつかを学び、フィールドの日常的なタスクに慣れると、学習の旅は自分自身で進む必要があり、ほとんどのオンラインリソースは退屈で浅すぎると感じるかもしれません。しかし、新しいディープラーニングの論文を公開されると同時に学び、それに必要な数学の要素(著者の理論の数学的証明ではなく)を理解し、それを効率的なコードに実装できる能力のあるコーダーであれば、フィールドの最新情報を把握し、新しいアイデアを学ぶことは誰にも止められません。

コントラスティブロスの実装

私の日常と私が数学を実装するために従う手順を紹介します。それは簡単ではない例である「SimCLR論文」の「コントラスティブロス」です。

ここに損失の数学的定式化があります:

SimCLR論文からのコントラスティブ(NT-Xent)ロス | https://arxiv.org/pdf/2002.05709.pdfより

この式の見た目は驚くほどです!おそらくGitHubにはたくさんのPyTorchの実装があると思いますので、それらを使用しましょう:) そして、はい、正解です。オンラインには数十の実装があります。ただし、このスキルを練習するための良い例であり、良い出発点となると思います。

コード中で数学を実装する手順

論文中の数学を効率的なPyTorchコードに実装するための私の手順は次のとおりです:

  1. 数学を理解し、簡単な言葉で説明する
  2. 単純なPythonの「for」ループを使用して初期バージョンを実装する(今のところ高度な行列の乗算は必要ありません)
  3. コードを効率的で行列に対応したPyTorchコードに変換する

OK、まず最初のステップに進みましょう。

ステップ1:数学を理解し、簡単な言葉で説明する

線形代数の基礎知識があり、数学の記法に精通していることを前提としています。もしそうでない場合は、このツールを使って各記号が何であり、数学で何をするかを描くことで知ることができます。また、ほとんどの記号が説明されている素晴らしいWikipediaのページもご覧いただけます。これらは新しいことを学ぶ機会であり、必要な時に必要なものを検索して読むことで学ぶよりも効率的な方法だと私は信じています。数日後に数学の教科書から始めて、数日後にそれをやめるという方法よりもです:)

ビジネスに戻りましょう。数式の上の段落がより多くの文脈を追加しており、SimCLRの学習戦略では、N枚の画像から始めて、それぞれを2回変換してこれらの画像の拡張ビューを取得します(2 * N枚の画像)。次に、これらの2 * N枚の画像をモデルに渡してそれぞれの埋め込みベクトルを取得します。そして、同じ画像の2つの拡張ビュー(正のペア)の埋め込みベクトルを埋め込み空間でより近づけたいと思います(他のすべての正のペアにも同様の操作を行います)。2つのベクトルがどれだけ似ているか(近い、同じ方向にある)を測定する方法の一つは、コサイン類似度を使用することです。これはsim(u, v)と定義されています(上の画像で定義を調べてください)。

単純に言えば、この式が説明していることは、バッチ内の各アイテムについて、画像の拡張ビューの埋め込みである(注:バッチにはさまざまな画像の拡張ビューの埋め込みが含まれている→N枚の画像で始める場合、バッチのサイズは2*Nになる)ことです。最初に、その画像の他の拡張ビューの埋め込みを見つけて、正のペアを作ります。次に、これらの2つの埋め込みのコサイン類似度を計算し、それを指数関数化します(式の分子)。次に、最初に始めた最初の埋め込みベクトルと他のすべてのペアのコサイン類似度を指数関数化し(ただし、それ自体のペアを除く、これが式中の1[k!=i]の意味です)、それらを合計して分母を作ります。これで、分子を分母で割り、その自然対数を取り、符号を反転させたものが、バッチ内の最初のアイテムの損失となります。バッチ内の他のすべてのアイテムについて同じプロセスを繰り返し、平均を取り、PyTorchの.backward()メソッドを呼び出して勾配を計算できるようにします。

ステップ2:単純なPythonコードを使用して実装する、単純な「for」ループを使用

遅い「for」ループを使用したシンプルなPython実装

コードを見てみましょう。2つの画像AとBがあるとします。変数aug_views_1は、これら2つの画像(A1とB1)の1つの拡張ビューの埋め込み(それぞれサイズ3)を保持しています。aug_views_2も同様に(A2とB2)。つまり、両方の行列の最初のアイテムは画像Aに関連し、両方の2番目のアイテムは画像Bに関連しています。2つの行列をプロジェクション行列に連結します(A1、B1、A2、B2の4つのベクトルが含まれています)。

プロジェクション行列内のベクトルの関係を保持するために、連結行列に関連する2つのアイテムを保存するpos_pairs辞書を定義します(まもなくF.normalize()のことを説明します!)。

コードの次の行で、forループを使用してプロジェクション行列のアイテムをループして、辞書を使用して関連するベクトルを見つけ、そのコサイン類似度を計算しています。なぜベクトルのサイズで割らないのかと思うかもしれませんが、コサイン類似度の式が示唆しているように、ループを開始する前に、F.normalize関数を使用してプロジェクション行列内のすべてのベクトルを正規化して、サイズが1になるようにしています。したがって、コサイン類似度を計算する行でサイズで割る必要はありません。

分子を構築した後、バッチ内の他のベクトルのインデックス(iと同じインデックスを除く)を見つけて、分母を構成するコサイン類似度を計算します。最後に、分子を分母で割り、対数関数を適用して符号を反転させて損失を計算します。各行で何が起こっているかを理解するために、コードを試してみてください。

ステップ3:効率的な行列操作に変換するPyTorchコードへの変換

前のPython実装の問題は、トレーニングパイプラインで使用するには遅すぎることです。遅い「for」ループを取り除き、並列化の力を活用するために、行列の乗算と配列操作に変換する必要があります。

PyTorch実装

このコードスニペットで何が起こっているかを見てみましょう。今回は、labels_1とlabels_2テンソルを導入して、これらの画像が属する任意のクラスをエンコードします。A1、A2、B1、B2の画像の関係をエンコードする方法が必要です。私が選んだように、ラベル0と1を選んでも、5と8と言っても問題ありません。

埋め込みとラベルを連結した後、すべての可能なペアのコサイン類似度を含むsim_matrixを作成します。

How the sim_matrix looks like: the green cells contain our positive pairs, the orange cells are the pairs which need to be ignored in the denominator | Visualization by the author

上記の可視化は、コードがどのように動作しているか、なぜその手順を行っているかを理解するために必要なすべてです。sim_matrixの最初の行を考慮すると、バッチの最初のアイテム(A1)の損失を次のように計算できます。A1A2(指数化)をA1B1、A1A2、およびA1B2(それぞれを最初に指数化)の合計で除算し、すべての損失を格納するテンソルの最初のアイテムに結果を保持する必要があります。したがって、まず、上記の可視化で緑色のセルを見つけるためのマスクを作成する必要があります。変数maskを定義するコードの2行は、まさにこれを行います。分子は、sim_matrixを作成したマスクで乗算し、各行のアイテムを合計します(マスキング後、各行には1つの非ゼロのアイテム、つまり緑色のセルしかありません)。分母を計算するために、対角線上のオレンジ色のセルを無視して、各行ごとに合計する必要があります。これを行うためには、PyTorchテンソルの.diag()メソッドを使用します。残りは自己説明的です!

ボーナス: AIアシスタント(ChatGPT、Copilotなど)を使用して式を実装する

私たちは、ディープラーニングの論文の数学を理解し、実装するために非常に優れたツールを手に入れています。たとえば、論文から式を与えた後、ChatGPT(または他の類似のツール)にPyTorchでコードを実装するように依頼することができます。私の経験では、ChatGPTは、Pythonicなforループの実装ステップに何とか自分自身を持ち込むことができれば、最も役に立ち、最良の最終回答を提供することができます。その素朴な実装をChatGPTに渡し、行列の乗算とテンソルの操作のみを使用する効率的なPyTorchコードに変換するように依頼してみてください。驚くことでしょう 🙂

さらなる読み物

以下の2つの素晴らしい実装をチェックすることをお勧めします。同じアイデアを考慮に入れて、この実装をより微妙な状況に拡張する方法を学ぶことができます。例えば、教師あり対比損失設定など。

  1. Guillaume ErhardによるSupervised Contrastive Loss
  2. Yonglong TianによるSupContrast

私について

私は、医学的イメージングアプリケーションにおける深層学習ソリューションの使用に焦点を当てた、機械学習開発者であり医学生です。私の研究は主に、さまざまな状況下での深層モデルの汎化性を調査することにあります。お気軽に私に連絡してください。メール、Twitter、またはLinkedInを介して。

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

機械学習

「RBIは、規制監督のためにAIを活用するために、マッキンゼーとアクセンチュアと提携します」

規制監督における重要な変化を示す動きとして、インド準備銀行(RBI)は、国際的なコンサルティング企業であるマッキンゼー・...

データサイエンス

「ChatGPTにおける適切なプロンプト設計の必須ガイド」

「Prompt Engineering」に没頭して、急速に成長しているChatGPTユーザーベースに与える影響に焦点を当てた詳細なガイドで、プ...

AIテクノロジー

「最も価値のあるコードは、書くべきでないコードです」

伝統的なプログラミング言語のコーディングスキルは、AIが進化するにつれてますます重要ではなくなります私はコーディングな...

人工知能

5分で作成するLow-Code GPT AIアプリを作成する

AIとデータベースの相互作用にAIのツール、AINIROとOpenAIのGPTを組み合わせることで、5分で完全なデータベースをCRUDアプリ...

機械学習

ビジネスにおけるAIの潜在的なリスクの理解と軽減

「この技術を導入する際に遭遇する可能性のあるAIのリスクを学びましょうビジネスオーナーとして、そのようなリスクを避ける...

機械学習

AIのダークサイドを明らかにする:プロンプトハッキングがあなたのAIシステムを妨害する方法

LLMsによるハッキングを防止し、データを保護するために、AIシステムを保護してくださいこの新興脅威に対するリスク、影響、...