CLIPSegによるゼロショット画像セグメンテーション
'Zero-shot image segmentation using CLIPSeg.'
このガイドでは、🤗 transformers
を使用して、ゼロショットの画像セグメンテーションモデルであるCLIPSegを使用する方法を紹介します。CLIPSegは、ロボットの知覚、画像補完など、さまざまなタスクに使用できるラフなセグメンテーションマスクを作成します。より正確なセグメンテーションマスクが必要な場合は、Segments.aiでCLIPSegの結果を改善する方法も紹介します。
画像セグメンテーションは、コンピュータビジョンの分野でよく知られたタスクです。これにより、コンピュータは画像内の物体を知るだけでなく(分類)、画像内の物体の位置を知ることもできます(検出)、さらには物体の輪郭も知ることができます。物体の輪郭を知ることは、ロボット工学や自動運転などの分野では重要です。たとえば、ロボットは物体の形状を正しく把握するために、その形状を知る必要があります。セグメンテーションは、画像補完と組み合わせることもでき、ユーザーが画像のどの部分を置き換えたいかを説明することができます。
ほとんどの画像セグメンテーションモデルの制限の1つは、固定されたカテゴリのリストでのみ機能するということです。たとえば、オレンジでトレーニングされたセグメンテーションモデルを使用して、リンゴをセグメント化することはできません。セグメンテーションモデルに追加のカテゴリを教えるには、新しいカテゴリのデータをラベル付けし、新しいモデルをトレーニングする必要があります。これは費用と時間がかかる場合があります。しかし、さらなるトレーニングなしにほとんどどのような種類のオブジェクトでもセグメント化できるモデルがあったらどうでしょうか?それがCLIPSeg、ゼロショットのセグメンテーションモデルが達成するものです。
- インテルのサファイアラピッズを使用してPyTorch Transformersを高速化する – パート1
- ゲーム開発のためのAI:5日間で農業ゲームを作成するパート1
- ゲーム開発のためのAI:5日間で農業ゲームを作成するパート2
現時点では、CLIPSegにはまだ制限があります。たとえば、モデルは352 x 352ピクセルの画像を使用するため、出力はかなり低解像度です。したがって、モダンなカメラの画像を使用すると、ピクセルパーフェクトな結果を期待することはできません。より正確なセグメンテーションを必要とする場合、前のブログ記事で示したように、最新のセグメンテーションモデルを微調整することができます。その場合、CLIPSegを使用してラフなラベルを生成し、Segments.aiなどのラベリングツールでそれらを調整することができます。それについて説明する前に、まずCLIPSegの動作を見てみましょう。
CLIP: CLIPSegの背後にある魔法のモデル
CLIP(Contrastive Language–Image Pre-training)は、OpenAIが2021年に開発したモデルです。CLIPに画像またはテキストの一部を与えると、CLIPは入力の抽象的な表現を出力します。この抽象的な表現、または埋め込みとも呼ばれるものは、実際にはベクトル(数値のリスト)です。このベクトルは、高次元空間のポイントと考えることができます。CLIPは、似たような画像とテキストの表現も似たようにするようにトレーニングされています。つまり、画像とそれに合致するテキストの説明を入力すると、画像とテキストの表現が似ている(つまり、高次元のポイントが近くにある)ことになります。
最初はあまり役に立たないように思えるかもしれませんが、実際には非常に強力です。例えば、CLIPを使用して訓練されたことがないタスクで画像を分類する方法を簡単に見てみましょう。画像を分類するには、画像と選択肢となる異なるカテゴリをCLIPに入力します(例えば、画像と「りんご」、「オレンジ」などの単語を入力します)。CLIPは、画像と各カテゴリの埋め込みを返します。今、画像の埋め込みに最も近いカテゴリの埋め込みを確認するだけです。これで完了です!まるで魔法のようですね。
CLIPを使用した画像分類の例(出典)。
さらに、CLIPは分類だけでなく、画像検索(これが分類と似ていることがわかりますか?)、テキストから画像への変換モデル(DALL-E 2はCLIPで動作します)、物体検出(OWL-ViT)などにも使用できます。そして、私たちにとって最も重要なのは、画像セグメンテーションです。これでCLIPが機械学習において本当に画期的なものである理由がお分かりいただけるでしょう。
CLIPが非常にうまく機能する理由は、モデルがテキストのキャプション付きの膨大なデータセットでトレーニングされたからです。そのデータセットには、インターネットから取得した4億枚の画像テキストペアが含まれています。これらの画像にはさまざまなオブジェクトや概念が含まれており、CLIPはそれぞれのオブジェクトに対して表現を生成するのに優れています。
CLIPSeg: CLIPによる画像セグメンテーション
CLIPSegは、CLIPの表現を使用して画像セグメンテーションマスクを作成するモデルです。Timo LüddeckeさんとAlexander Eckerさんによって公開されました。彼らは、CLIPモデルを凍結したまま、TransformerベースのデコーダをCLIPモデルの上にトレーニングすることで、ゼロショット画像セグメンテーションを達成しました。デコーダは、画像のCLIP表現とセグメンテーションしたい対象のCLIP表現を入力として受け取り、これらの2つの入力を使用して、CLIPSegデコーダは2値のセグメンテーションマスクを作成します。より詳しく言うと、デコーダはセグメンテーションしたい画像の最終的なCLIP表現だけでなく、CLIPのいくつかのレイヤーの出力も使用します。
ソース
デコーダは、PhraseCutデータセットでトレーニングされています。このデータセットには、340,000以上のフレーズと対応する画像セグメンテーションマスクが含まれています。著者たちはまた、データセットのサイズを拡大するためにさまざまな拡張方法も試みました。ここでの目標は、データセットに存在するカテゴリだけでなく、未知のカテゴリもセグメンテーションできるようにすることです。実験の結果、デコーダは未知のカテゴリにも対応できることが示されています。
CLIPSegの興味深い特徴の1つは、クエリ(セグメンテーションしたい画像)とプロンプト(画像内のセグメンテーションしたい対象)の両方がCLIPの埋め込みとして入力されることです。プロンプトのCLIPの埋め込みは、テキスト(カテゴリ名)から来ることもありますが、または他の画像から来ることもあります。つまり、CLIPSegにオレンジの例の画像を与えることで、写真内のオレンジをセグメンテーションできます。
この「ビジュアルプロンプティング」と呼ばれる技術は、セグメンテーションしたい対象を説明するのが難しい場合に非常に役立ちます。たとえば、Tシャツの写真の中のロゴをセグメンテーションしたい場合、ロゴの形状を簡単に説明することは難しいですが、CLIPSegを使用すると、単にロゴの画像をプロンプトとして使用することができます。
CLIPSegの論文には、ビジュアルプロンプティングの効果を向上させるいくつかのヒントが含まれています。クエリ画像をクロッピングする(セグメンテーションしたいオブジェクトのみを含むようにする)ことが非常に助けになることがわかりました。また、クエリ画像の背景をぼかしたり暗くしたりすることも少し助けになります。次のセクションでは、🤗 transformers
を使用して自分自身でビジュアルプロンプティングを試す方法を示します。
Hugging Face TransformersでCLIPSegを使用する
Hugging Face Transformersを使用すると、事前学習済みのCLIPSegモデルを簡単にダウンロードして実行できます。まず、transformersをインストールしてみましょう。
!pip install -q transformers
モデルをダウンロードするには、単にインスタンス化します。
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
それでは、セグメンテーションを試すために画像をロードしましょう。Calum Lewisさんが撮影したおいしい朝食の写真を選びましょう。
from PIL import Image
import requests
url = "https://unsplash.com/photos/8Nc_oQsc2qQ/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjcxMjAwNzI0&force=true&w=640"
image = Image.open(requests.get(url, stream=True).raw)
image
テキストプロンプティング
まず、セグメンテーションしたいテキストカテゴリを定義しましょう。
prompts = ["cutlery", "pancakes", "blueberries", "orange juice"]
入力が揃ったので、それらを処理してモデルに入力しましょう。
import torch
inputs = processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")
# 予測
with torch.no_grad():
outputs = model(**inputs)
preds = outputs.logits.unsqueeze(1)
最後に、出力を視覚化しましょう。
import matplotlib.pyplot as plt
_, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len(prompts))];
[ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];
視覚的なプロンプト
前述のように、カテゴリ名の代わりに画像を入力プロンプトとして使用することもできます。これは、セグメント化したい対象を簡単に説明することができない場合に特に役立ちます。この例では、Daniel Hooperによって撮影されたコーヒーカップの写真を使用します。
url = "https://unsplash.com/photos/Ki7sAc8gOGE/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTJ8fGNvZmZlJTIwdG8lMjBnb3xlbnwwfHx8fDE2NzExOTgzNDQ&force=true&w=640"
prompt = Image.open(requests.get(url, stream=True).raw)
prompt
今、入力画像とプロンプト画像を処理し、モデルに入力します。
encoded_image = processor(images=[image], return_tensors="pt")
encoded_prompt = processor(images=[prompt], return_tensors="pt")
# 予測
with torch.no_grad():
outputs = model(**encoded_image, conditional_pixel_values=encoded_prompt.pixel_values)
preds = outputs.logits.unsqueeze(1)
preds = torch.transpose(preds, 0, 1)
そして、前と同様に結果を可視化することができます。
_, ax = plt.subplots(1, 2, figsize=(6, 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
ax[1].imshow(torch.sigmoid(preds[0]))
最後に、論文で説明されている視覚的なプロンプトのヒントを使用して、最後の試みをしてみましょう。つまり、画像をトリミングし、背景を暗くします。
url = "https://i.imgur.com/mRSORqz.jpg"
alternative_prompt = Image.open(requests.get(url, stream=True).raw)
alternative_prompt
encoded_alternative_prompt = processor(images=[alternative_prompt], return_tensors="pt")
# 予測
with torch.no_grad():
outputs = model(**encoded_image, conditional_pixel_values=encoded_alternative_prompt.pixel_values)
preds = outputs.logits.unsqueeze(1)
preds = torch.transpose(preds, 0, 1)
_, ax = plt.subplots(1, 2, figsize=(6, 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
ax[1].imshow(torch.sigmoid(preds[0]))
この場合、結果はほぼ同じです。これは、元の画像でコーヒーカップが背景から十分に分離されているためである可能性があります。
Segments.aiで画像の事前ラベル付けをするためにCLIPSegを使用する
ご覧のように、CLIPSegの結果は少しぼやけており、非常に低解像度です。より良い結果を得たい場合は、前のブログ記事で説明されているように、最新のセグメンテーションモデルを微調整することができます。モデルを微調整するためには、ラベル付きデータが必要です。このセクションでは、画像セグメンテーションのためのスマートなラベリングツールを備えたセグメンテーションプラットフォームであるSegments.aiを使用して、いくつかの粗いセグメンテーションマスクを作成し、それを改良する方法を紹介します。
まず、https://segments.ai/join でアカウントを作成し、Segments Python SDKをインストールします。次に、APIキーを使用してSegments.ai Pythonクライアントを初期化します。このキーはアカウントページで見つけることができます。
!pip install -q segments-ai
from segments import SegmentsClient
from getpass import getpass
api_key = getpass('APIキーを入力してください: ')
segments_client = SegmentsClient(api_key)
次に、Segmentsクライアントを使用してデータセットから画像をロードしましょう。a2d2自動運転データセットを使用します。また、これらの手順に従って独自のデータセットを作成することもできます。
samples = segments_client.get_samples("admin-tobias/clipseg")
# 最後の画像を例として使用します
sample = samples[1]
image = Image.open(requests.get(sample.attributes.image.url, stream=True).raw)
image
また、データセット属性からカテゴリ名を取得する必要があります。
dataset = segments_client.get_dataset("admin-tobias/clipseg")
category_names = [category.name for category in dataset.task_attributes.categories]
これで、以前と同様に画像にCLIPSegを使用することができます。今回は、出力を入力画像のサイズに合わせるために、出力を拡大します。
from torch import nn
inputs = processor(text=category_names, images=[image] * len(category_names), padding="max_length", return_tensors="pt")
# 予測
with torch.no_grad():
outputs = model(**inputs)
# 出力をリサイズ
preds = nn.functional.interpolate(
outputs.logits.unsqueeze(1),
size=(image.size[1], image.size[0]),
mode="bilinear"
)
そして結果を再度視覚化することができます。
len_cats = len(category_names)
_, ax = plt.subplots(1, len_cats + 1, figsize=(3*(len_cats + 1), 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len_cats)];
[ax[i+1].text(0, -15, category_name) for i, category_name in enumerate(category_names)];
今度は予測結果を1つのセグメンテーション画像に結合する必要があります。各パッチに対して、最もシグモイド値が大きいカテゴリを取ることでこれを行います。また、ある閾値以下のすべての値をカウントしないようにします。
threshold = 0.1
flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1))
# 閾値を持つダミーの "未ラベル付け" マスクを初期化する
flat_preds_with_treshold = torch.full((preds.shape[0] + 1, flat_preds.shape[-1]), threshold)
flat_preds_with_treshold[1:preds.shape[0]+1,:] = flat_preds
# 各ピクセルのトップマスクインデックスを取得する
inds = torch.topk(flat_preds_with_treshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
結果をすばやく視覚化しましょう。
plt.imshow(inds)
最後に、予測をSegments.aiにアップロードする必要があります。そのために、ビットマップをpngファイルに変換し、このファイルをSegmentsにアップロードし、最後にラベルをサンプルに追加します。
from segments.utils import bitmap2file
import numpy as np
inds_np = inds.numpy().astype(np.uint32)
unique_inds = np.unique(inds_np).tolist()
f = bitmap2file(inds_np, is_segmentation_bitmap=True)
asset = segments_client.upload_asset(f, "clipseg_prediction.png")
attributes = {
'format_version': '0.1',
'annotations': [{"id": i, "category_id": i} for i in unique_inds if i != 0],
'segmentation_bitmap': { 'url': asset.url },
}
segments_client.add_label(sample.uuid, 'ground-truth', attributes)
Segments.aiでアップロードされた予測を見ると、完璧ではないことがわかります。ただし、最も大きな間違いを手動で修正し、修正済みのデータセットを使用してCLIPSegよりも優れたモデルをトレーニングすることができます。
結論
CLIPSegは、テキストと画像のプロンプトの両方で動作するゼロショットセグメンテーションモデルです。このモデルはCLIPにデコーダを追加し、ほぼすべてをセグメンテーションできます。ただし、出力のセグメンテーションマスクは現時点では非常に低解像度ですので、精度が重要な場合は別のセグメンテーションモデルを微調整する必要があります。
現在、ゼロショットセグメンテーションに関するさらなる研究が行われており、近い将来さらに多くのモデルが追加されることが期待されます。一つの例はGroupViTで、既に🤗 Transformersで利用可能です。セグメンテーション研究の最新ニュースについては、Twitterで私たちをフォローしてください:@TobiasCornille、@NielsRogge、@huggingface。
最先端のセグメンテーションモデルを微調整する方法に興味がある場合、前のブログ記事をチェックしてください:https://huggingface.co/blog/fine-tune-segformer。
We will continue to update VoAGI; if you have any questions or suggestions, please contact us!
Was this article helpful?
93 out of 132 found this helpful
Related articles