PyTorchを使用した効率的な画像セグメンテーション:Part 4

Efficient Image Segmentation using PyTorch Part 4

ビジョン・トランスフォーマーに基づくモデル

この4部作では、PyTorchを使用して、深層学習技術を使用して画像セグメンテーションをゼロから段階的に実装します。この部分では、画像セグメンテーションのためにビジョン・トランスフォーマーに基づくモデルを実装することに焦点を当てます。

Naresh Singhと共著

Figure 1: Result of running image segmentation using a vision transformer model architecture. From top to bottom, input images, ground truth segmentation masks, and predicted segmentation masks. Source: Author(s)

記事の概要

この記事では、ディープラーニング界を席巻したTransformerアーキテクチャについて紹介します。Transformerは、言語、ビジョン、音声などの異なるモダリティをモデル化できるマルチモーダルなアーキテクチャです。

この記事では、以下を行います。

  1. Transformerアーキテクチャと関連するキー・コンセプトの学習
  2. ビジョン・トランスフォーマー・アーキテクチャの理解
  3. Scratchから書かれたビジョン・トランスフォーマー・モデルの紹介。すべての構成要素と動く部品を理解できるようにします。
  4. このモデルに入力テンソルを送り、形状がどのように変化するかを検査します。
  5. Oxford IIIT Petデータセットでこのモデルを使用して画像セグメンテーションを実行します。
  6. このセグメンテーション・タスクの結果を観察します。
  7. セマンティック・セグメンテーションのための最先端のビジョン・トランスフォーマーであるSegFormerを簡単に紹介します。

この記事では、モデルトレーニングのコードと結果を参照します。結果を再現する場合は、最初のノートブックが合理的な時間内に完了するようにGPUが必要です。

このシリーズの記事

このシリーズは、深層学習における実践、ビジョンAI、堅実な理論と実践経験について学びたいすべての経験レベルの読者を対象としています。これは、以下の記事で構成される4部作になる予定です。

  1. コンセプトとアイデア
  2. CNNベースのモデル
  3. Depthwise separable convolutions
  4. ビジョン・トランスフォーマーに基づくモデル(この記事)

ビジョン・トランスフォーマーへの旅を始めましょう。Transformerアーキテクチャの紹介と直感的な理解から始めましょう。

Transformerアーキテクチャ

Transformerアーキテクチャは、通信と計算の交互レイヤーの構成として考えることができます。このアイデアは、図2に視覚的に表示されます。TransformerにはN個の処理ユニット(図2ではN=3)があり、それぞれは入力の1/Nを処理する責任があります。それらの処理ユニットが意味のある結果を生み出すためには、それぞれが入力のグローバルビューを持つ必要があります。したがって、システムは、すべての処理ユニットからすべての他の処理ユニットにデータに関する情報を繰り返し通信し(赤、緑、青の矢印によって示される)、その情報に基づいて計算が行われます。このプロセスを十分に繰り返すことで、モデルは望ましい結果を生成できます。

Figure 2: Interleaved communication and computation in transformers. The image shows just 2 layers of communication and computation. In practice, there are many more such layers. Source: Author(s).

オンラインリソースのほとんどは、通常、「Attention is all you need」という論文で提供されたエンコーダーとデコーダーの両方を説明しています。しかし、この記事では、Transformerのエンコーダー部分のみを説明します。

Transformerにおける通信と計算の構成要素を詳しく見てみましょう。

Transformerにおける通信:Attention

Transformerにおいて、通信は、アテンション・レイヤーとして実装されます。PyTorchでは、これはMultiHeadAttentionと呼ばれます。その名前の理由については、少し後で説明します。

ドキュメンテーションには以下のように記載されています:

「論文で説明されているように、異なる表現サブスペースから情報に共同で注意を払うことができます:Attention is all you need。」

アテンション機構は、形状(バッチ、長さ、特徴)の入力テンソルxを消費し、同様の形状のテンソルyを生成します。テンソルは、同じインスタンス内の他の入力に注目してそれらの特徴を更新することに基づいています。したがって、サイズ「長さ」の各テンソルの長さ「特徴」の特徴は、他のすべてのテンソルに基づいて更新されます。これがアテンション機構の二次コストが発生する場所です。

Figure 3: Attention of the word “it” shown relative to the other words in the sentence. We can see that “it” is paying attention to the words “animal”, “too”, and “tire(d)” in the same sentence. Source: Generated using this colab .

ビジョントランスフォーマーの文脈では、トランスフォーマーへの入力は画像です。これを128 x 128(幅、高さ)の画像とします。これを16 x 16の複数の小さなパッチに分割します。128 x 128の画像の場合、64のパッチ(長さ)があり、各行に8つのパッチ、8行のパッチがあります。

これら64の16 x 16ピクセルの各パッチは、トランスフォーマーモデルの別々の入力として考慮されます。詳細には踏み込まずに、これを64の異なる処理ユニットによって駆動されるプロセスと考えることが十分です。各処理ユニットのアテンション機構は、担当する画像パッチを見て、同じ画像内の他の63の処理ユニットすべてにクエリを送信し、自分自身の画像パッチを効果的に処理するために役立つ可能性のある情報を要求します。

アテンションによる通信ステップの後には、次に計算が行われます。

トランスフォーマーの計算:マルチレイヤーパーセプトロン

トランスフォーマーの計算は、MultiLayerPerceptron(MLP)ユニットにすぎません。このユニットは2つの線形層で構成され、途中にGeLU非線形性があります。他の非線形性も考慮できます。このユニットは、最初に入力を4倍のサイズに投影し、1倍に再投影します。これは入力サイズと同じです。

私たちのノートブックで見るコードでは、このクラスはMultiLayerPerceptronと呼ばれます。コードは以下のようになります。

class MultiLayerPerceptron(nn.Sequential):    def __init__(self, embed_size, dropout):        super().__init__(            nn.Linear(embed_size, embed_size * 4),            nn.GELU(),            nn.Linear(embed_size * 4, embed_size),            nn.Dropout(p=dropout),        )    # end def# end class

トランスフォーマーアーキテクチャの高レベルな動作を理解したので、私たちは画像セグメンテーションを実行するためにビジョントランスフォーマーに注目することにしましょう。

ビジョントランスフォーマー

ビジョントランスフォーマーは、最初に「16×16の単語で表されるイメージ:規模の大きい画像認識のためのトランスフォーマー」というタイトルの論文で紹介されました。論文では、著者がバニラトランスフォーマーアーキテクチャを画像分類の問題に適用する方法を説明しています。これは、画像を16×16のパッチに分割し、各パッチをモデルへの入力トークンとして扱うことによって行われます。トランスフォーマーエンコーダーモデルはこれらの入力トークンを受け取り、入力画像に対してクラスを予測するように求められます。

Figure 4: Source: Transformers for image recognition at scale .

私たちの場合、画像セグメンテーションに興味があります。それは、ピクセルレベルの分類タスクと考えることができます。なぜなら、入力画像に対して1ピクセルあたりのターゲットクラスを予測することを意図しているからです。

私たちはバニラのビジョン・トランスフォーマーに小さながらも重要な変更を加え、分類用のMLPヘッドを画素レベル分類用のMLPヘッドに置き換えました。ビジョン・トランスフォーマーによって予測されたセグメンテーション・マスクを持つ各パッチについて、出力に単一の線形層があります。この共有線形層は、モデルに入力された各パッチのセグメンテーション・マスクを予測します。

ビジョン・トランスフォーマーの場合、サイズ16×16のパッチは、特定の時間ステップにおける単一の入力トークンに相当すると見なされます。

Figure 5: The end to end working of the vision transformer for image segmentation. Image generated using this notebook . Source: Author(s).

ビジョン・トランスフォーマーにおけるテンソル次元の直感を構築する

深層CNNを扱う際に、私たちが使用したテンソルの次元は、(N、C H、W)であり、文字は以下を表します。

  • N:バッチサイズ
  • C:チャネル数
  • H:高さ
  • W:幅

これは、画像に非常に特化した特徴を匂わせているため、2D画像処理に適していると言えます。

一方、トランスフォーマーを使用する場合、事柄ははるかに汎用的でドメインに依存しません。以下で説明することは、ビジョン、テキスト、NLP、オーディオまたはその他の問題に適用されます。入力データがシーケンスとして表現できる場合です。テンソルの表現を通じてビジョン・トランスフォーマーを流れるとき、ビジョン固有のバイアスはほとんどありません。

トランスフォーマーと注意を扱う際には、テンソルが次の形状を持つことを期待します:(B、T、C)。文字は以下を表します。

  • B:バッチサイズ(CNNと同じ)
  • T:時間次元またはシーケンスの長さ。この次元はLとも呼ばれます。ビジョン・トランスフォーマーの場合、各画像パッチはこの次元に対応します。16個の画像パッチがある場合、T次元の値は16になります。
  • C:チャネルまたは埋め込みサイズ次元。この次元はEとも呼ばれます。画像を処理する場合、サイズ3x16x16(チャネル、幅、高さ)の各パッチは、パッチ埋め込み層を介してサイズCの埋め込みにマップされます。後でこれがどのように行われるかを見ていきます。

次に、入力画像テンソルがどのように変異し、セグメンテーションマスクを予測するために処理されるかを見ていきましょう。

ビジョン・トランスフォーマー内のテンソルの旅

深層CNNでは、テンソルの旅は、UNet、SegNet、またはその他のCNNベースのアーキテクチャで次のようになります。

入力テンソルは通常、形状(1、3、128、128)です。このテンソルは、空間次元が縮小し、チャネル次元が通常2倍に増加する一連の畳み込みと最大プーリング操作を行います。これを特徴エンコーダーと呼びます。その後、逆操作を行い、空間次元を増加させ、チャネル次元を減少させます。これを特徴デコーダーと呼びます。デコードプロセスの後、テンソルの形状は(1、64、128、128)となります。これは、バイアスのない1×1の点の畳み込みを使用して、所望の出力チャネルCに(1、C、128、128)として投影されます。

Figure 6: Typical progression of tensor shapes through a deep CNN used for image segmentation. Source: Author(s).

ビジョン・トランスフォーマーでは、フローははるかに複雑です。以下の画像を見て、テンソルが各ステップでどのように形状を変化させるかを理解してみましょう。

Figure 7: Typical progression of tensor shapes through a vision transformer for image segmentation. Source: Author(s).

ビジョン・トランスフォーマーを流れるテンソルの形状がどのように更新されるか、各ステップを詳しく見てみましょう。これをより理解するために、テンソルの寸法に具体的な値を取ってみましょう。

  1. バッチ正規化:入力および出力テンソルの形状は(1, 3, 128, 128)です。形状は変わりませんが、値はゼロ平均および単位分散に正規化されます。
  2. 画像からパッチへ:形状が(1, 3, 128, 128)の入力テンソルは、16×16の画像を積み重ねたパッチに変換されます。出力テンソルの形状は(1, 64, 768)です。
  3. パッチ埋め込み:パッチ埋め込み層は、768の入力チャネルを512の埋め込みチャネルにマッピングします(この例では)。出力テンソルは形状が(1, 64, 512)です。パッチ埋め込み層は、基本的にはPyTorchのnn.Linear層です。
  4. 位置埋め込み:位置埋め込み層には入力テンソルはありませんが、パッチ埋め込みと同じ形状の学習可能なテンソル(PyTorchのtrainable tensor)を効果的に提供します。これは形状が(1, 64, 512)です。
  5. 加算:パッチと位置の埋め込みは、ピースワイズに加算され、ビジョン・トランスフォーマー・エンコーダの入力を生成します。このテンソルの形状は(1, 64, 512)です。ビジョン・トランスフォーマーの主要なワークホースであるエンコーダは、このテンソルの形状を変更せずに残します。
  6. トランスフォーマー・エンコーダ:形状が(1, 64, 512)の入力テンソルは、複数のトランスフォーマー・エンコーダ・ブロックを通過し、各ブロックには複数のアテンション・ヘッド(通信)が続き、MLP層(計算)が続きます。テンソルの形状は(1, 64, 512)のままです。
  7. 線形出力プロジェクション:16×16の各パッチを10チャネルにセグメンテーションしたいと仮定すると、出力プロジェクションのためのnn.Linear層は、512の埋め込みチャネルを16x16x10=2560の出力チャネルに変換します。このテンソルは(1, 64, 2560)のように見えます。上記の図ではC’ = 10です。理想的には、これは多層パーセプトロンであるべきです。なぜなら、「MLPは普遍的な関数近似器である」からですが、これは教育的な演習であるため、単一の線形層を使用しています。
  8. パッチから画像へ:このレイヤーは、(1, 64, 2560)のテンソルとしてエンコードされた64のパッチを、セグメンテーション・マスクのように見えるものに戻します。これは10個の単一チャネル画像、またはこの場合は1つの10チャネル画像であり、各チャネルは10クラスのうち1つのセグメンテーション・マスクです。出力テンソルの形状は(1, 10, 128, 128)です。

これで、ビジョン・トランスフォーマーを使用して入力画像をセグメンテーションすることができました!次に、実験と結果を見てみましょう。

アクション中のビジョン・トランスフォーマー

このノートブックには、このセクションのすべてのコードが含まれています。

コードとクラスの構造に関しては、上記のブロック図に密接に似ています。上記で触れたほとんどの概念は、このノートブック内のクラス名と1:1の対応関係があります。

トランスフォーマー内のアテンション層に関連するいくつかの概念は、モデルの重要なハイパーパラメータです。私たちは前述のマルチヘッドアテンションの詳細について何も言及していないため、この記事の目的には含まれていませんが、トランスフォーマー内のアテンションメカニズムの基本的な理解がない場合は、上記の参考資料を事前に読むことを強くお勧めします。

私たちは、セグメンテーションのためのビジョン・トランスフォーマーに以下のモデル・パラメータを使用しました。

  1. PatchEmbedding層の768埋め込み寸法
  2. 12トランスフォーマー・エンコーダ・ブロック
  3. トランスフォーマー・エンコーダ・ブロックごとに8つのアテンション・ヘッド
  4. マルチヘッドアテンションおよびMLPで20%のドロップアウト

この構成は、VisionTransformerArgs Pythonデータクラスで見ることができます。

@dataclassclass VisionTransformerArgs:    """Arguments to the VisionTransformerForSegmentation."""    image_size: int = 128    patch_size: int = 16    in_channels: int = 3    out_channels: int = 3    embed_size: int = 768    num_blocks: int = 12    num_heads: int = 8    dropout: float = 0.2# end class

モデルのトレーニングおよび検証中に以前と同様の構成が使用されました。構成は以下のように指定されています。

  1. 過学習を防ぐために、ランダムな水平反転と色のジッターのデータ拡張がトレーニングセットに適用されます。
  2. イメージは非アスペクト保存リサイズ操作で128×128ピクセルにリサイズされます。
  3. イメージに入力正規化は適用されません。代わりに、モデルの最初の層としてバッチ正規化レイヤーが使用されます。
  4. モデルは、Adamオプティマイザを使用して50エポックでトレーニングされ、LRは0.0004で、学習率を12エポックごとに0.8倍に減衰させるStepLRスケジューラーが使用されます。
  5. クロスエントロピー損失関数を使用して、ピクセルがペット、背景、またはペットの境界に属するかどうかを分類します。

モデルには86.28Mのパラメータがあり、50のトレーニングエポック後に85.89%の検証精度を達成しました。これは、20のトレーニングエポック後に深層CNNモデルが達成した88.28%の精度よりも低いです。これは、実験的に検証する必要があるいくつかの要因のためかもしれません。

  1. 最後の出力投影層は単一のnn.Linearであり、マルチレイヤーパーセプトロンではありません。
  2. 16×16パッチサイズは、より細かい粒度の詳細を捕捉するには大きすぎます。
  3. 十分なトレーニングエポックがありません。
  4. 十分なトレーニングデータがありません。トランスフォーマーモデルは、深層CNNモデルと比較して効果的にトレーニングするためには、より多くのデータが必要であることが知られています。
  5. 学習率が低すぎます。

私たちは、21枚の画像のセグメンテーションマスクを予測するためにモデルがどのように学習しているかを示すGIFをプロットしました。

Figure 8: A gif showing the progression of segmentation masks predicted by the vision transformer for image segmentation model. Source: Author(s).

初期のトレーニングエポックで興味深いことがわかりました。予測されたセグメンテーションマスクには、いくつかの奇妙なブロッキングアーティファクトがあります。これについて考えられる唯一の理由は、画像を16×16のパッチに分割しており、わずかなトレーニングエポック後に、モデルがこの16×16パッチが一般的にペットまたは背景ピクセルによってカバーされているかどうかに関する、非常に粗い情報以外に有用な情報を何も学んでいないためです。

Figure 9: The blocking artifacts seen in the predicted segmentation masks when using the vision transformer for image segmentation. Source: Author(s).

基本的なビジョントランスフォーマーの動作を見たので、次は最先端のビジョントランスフォーマーをセグメンテーションタスクに注目しましょう。

SegFormer:トランスフォーマーによる意味的セグメンテーション

SegFormerアーキテクチャは、この論文で2021年に提案されました。上記で説明したトランスフォーマーは、SegFormerアーキテクチャのより単純なバージョンです。

Figure 10: The SegFormer architecture. Source: SegFormer paper (2021) .

SegFormerの最も注目すべき点は次のとおりです。

  1. 16×16のパッチで構成される単一のパッチ画像ではなく、4×4、8×8、16×16、および32×32のパッチで構成される4つの画像セットを生成します。
  2. 1つのトランスフォーマーエンコーダーブロックではなく、4つのトランスフォーマーエンコーダーブロックを使用します。これはモデルアンサンブルのように感じます。
  3. セルフアテンションのプレおよびポストフェーズで畳み込みを使用します。
  4. 位置埋め込みを使用しません。
  5. 各トランスフォーマーブロックは、空間解像度H / 4 x W / 4、H / 8 x W / 8、H / 16 x W / 16、およびH / 32、W / 32でイメージを処理します。
  6. 同様に、空間次元が縮小するとチャンネルが増加します。これは、深層CNNに似ているように感じます。
  7. 複数の空間次元での予測はアップサンプリングされ、デコーダーで結合されます。
  8. MLPはこれらすべての予測を組み合わせて最終予測を提供します。
  9. 最終予測は空間次元H / 4、W / 4にあり、H、Wにはありません。

結論

このシリーズの第4部では、トランスフォーマーアーキテクチャとビジョン・トランスフォーマーについて紹介しました。ビジョン・トランスフォーマーがどのように機能し、ビジョン・トランスフォーマーの通信と計算フェーズに関与する基本的なビルディングブロックについて直感的に理解しました。ビジョン・トランスフォーマーが予測セグメンテーションマスクを行うために採用したユニークなパッチベースのアプローチを見ました。そして、予測を組み合わせるために使用される方法についても見ました。

ビジョン・トランスフォーマーを使用して示された実験をレビューし、Deep CNNアプローチとの結果を比較することができました。私たちのビジョン・トランスフォーマーは最先端ではありませんが、かなり良い結果を出すことができました。SegFormerなどの最新のアプローチの一部を紹介しました。

トランスフォーマーは、Deep CNNベースのアプローチと比較して、より多くの移動部品を持ち、より複雑です。 FLOPsの観点から見ると、トランスフォーマーはより効率的であることを約束しています。トランスフォーマーでは、実際に計算が重いのはnn.Linearレイヤーだけです。これは、ほとんどのアーキテクチャで最適化された行列乗算を使用して実装されています。このアーキテクチャの単純さにより、トランスフォーマーは、Deep CNNベースのアプローチと比較して最適化や高速化がより簡単であることを約束しています。

最後までお読みいただきありがとうございます!PyTorchで効率的な画像セグメンテーションに関するこのシリーズをお楽しみいただけたことをうれしく思います。ご質問やコメントがある場合は、コメント欄にご自由にお書きください。

さらに読む

この記事では、アテンションメカニズムの詳細は範囲外です。また、アテンションメカニズムを詳しく理解するために参照できる高品質のリソースがたくさんあります。以下は、強くお勧めするいくつかのリソースです。

  1. The Illustrated Transformer
  2. NanoGPT from scratch using PyTorch

以下は、ビジョン・トランスフォーマーに関する詳細な情報を提供する記事へのリンクです。

  1. Implementing Vision Transformer (ViT) in PyTorch:この記事では、PyTorchで画像分類のためのビジョン・トランスフォーマーを実装する方法について詳しく説明しています。注目すべきは、彼らの実装がeinopsを使用していることです。このため、コードの可読性のためにeinopsを学習して使用することをお勧めしますが、教育に焦点を当てた演習であるため、私たちはネイティブのPyTorchオペレータを使用してテンソルの次元を並べ替えるようにしています。また、著者がLinearレイヤーの代わりにConv2dを使用している場所がいくつかあります。私たちは、畳み込み層を使用せずにビジョン・トランスフォーマーの実装を構築することを望みました。
  2. Vision Transformer: AI Summer
  3. Implementing SegFormer in PyTorch

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

機械学習

PaLM AI | Googleの自家製生成AI

イントロダクション OpenAIによるGPT(Generative Pre-trained Transformers)モデル、特にChatGPTなどのような生成型AIモデ...

人工知能

「AIガバナンスにおけるステークホルダー分析の包括的ガイド(パート2)」

「著者注:本記事はAIガバナンスにおけるステークホルダー分析の包括的なガイドのパート2として書かれていますパート1はこち...

機械学習

「機械学習 vs AI vs ディープラーニング vs ニューラルネットワーク:違いは何ですか?」

テクノロジーの急速な進化は、ビジネスが効率化のために洗練されたアルゴリズムにますます頼ることで、私たちの日常生活を形...

機械学習

「プリズマーに会いましょう:専門家のアンサンブルを持つオープンソースのビジョン-言語モデル」

最近の多くのビジョン言語モデルは、非常に注目すべき多様な生成能力を示しています。しかし、通常、それらは膨大なモデルと...

データサイエンス

「ワードエンベディング:より良い回答のためにチャットボットに文脈を与える」

ワードエンベディングとChatGPTを使用してエキスパートボットを構築する方法を学びましょうワードベクトルの力を活用して、チ...

人工知能

「AIが航空会社のコントレイルによる気候への影響を軽減するのに役立っている方法」

「私たちはAIを使用して、航空会社がコントレイルの発生が少ないルートを選択するのを支援し、飛行の環境への影響を最小限に...