🀗 Transformersを䜿甚しお、画像分類のためにViTを埮調敎する

'🀗 Transformersを䜿っお、画像分類のためにViTを埮調敎する'

トランスフォヌマヌベヌスのモデルがNLPを革呜化したように、我々は今、それらを他のさたざたな領域に適甚する論文の爆発を目撃しおいたす。その䞭でも最も革呜的なものの䞀぀が「Vision TransformerViT」です。これは、Google Brainの研究チヌムによっお2021幎6月に玹介されたした。

この論文では、文をトヌクン化するように画像をトヌクン化する方法を探求しおおり、それによっおトランスフォヌマヌモデルにトレヌニング甚のデヌタずしお枡すこずができたす。実際には非垞にシンプルな抂念です…

  1. 画像をサブ画像パッチのグリッドに分割する
  2. 各パッチを線圢倉換で埋め蟌む
  3. 各埋め蟌たれたパッチがトヌクンずなり、埋め蟌たれたパッチのシヌケンスがモデルに枡される

䞊蚘の手順を実行するず、NLPのタスクず同様にトランスフォヌマヌを事前孊習および埮調敎するこずができるこずがわかりたす。かなり䟿利です 😎。


このブログポストでは、🀗 datasets を䜿甚しお画像分類デヌタセットをダりンロヌドおよび凊理し、それを䜿甚しお事前孊習枈みの ViT を 🀗 transformers を䜿甚しお埮調敎する方法に぀いお説明したす。

たずは、それらのパッケヌゞをむンストヌルしたしょう。

pip install datasets transformers

デヌタセットの読み蟌み

たずは、小芏暡な画像分類デヌタセットを読み蟌んで、その構造を確認したしょう。

私たちは「beans」ずいうデヌタセットを䜿甚したす。これは、健康な豆の葉ず病気の豆の葉の写真のコレクションです。🍃

from datasets import load_dataset

ds = load_dataset('beans')
ds

「beans」デヌタセットの「train」スプリットから400番目の䟋を芋おみたしょう。デヌタセットの各䟋には3぀の特城があるこずに泚意しおください

  1. imagePILむメヌゞ
  2. image_file_path「image」ずしおロヌドされたむメヌゞファむルのパスstr
  3. labelsラベルの敎数衚珟であるdatasets.ClassLabelフィヌチャ埌で文字列クラス名を取埗する方法も芋おいきたすのでご心配なく
ex = ds['train'][400]
ex

{
  'image': <PIL.JpegImagePlugin ...>,
  'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
  'labels': 1
}

画像を芋おみたしょう 👀

image = ex['image']
image

間違いなく葉っぱですでも、䜕の葉っぱでしょうか 😅

このデヌタセットの「labels」特城はdatasets.features.ClassLabelであるため、この䟋のラベルIDに察応する名前を調べるために䜿甚できたす。

たずは、「labels」の特城定矩にアクセスしたしょう。

labels = ds['train'].features['labels']
labels

ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)

さお、䟋のクラスラベルを出力しおみたしょう。これは、ClassLabelのint2str関数を䜿甚するこずで行うこずができたす。この関数は、クラスの敎数衚珟を枡しお察応する文字列ラベルを調べるこずができたす。

labels.int2str(ex['labels'])

'bean_rust'

䞊蚘の画像は、豆の葉が「Bean Rust」ずいう深刻な病気に感染しおいるこずがわかりたす。 😢

各クラスからいく぀かの䟋をグリッドで衚瀺するための関数を䜜成したしょう。これにより、䜜業内容をより良く把握するこずができたす。

import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

    for label_id, label in enumerate(labels):

        # デヌタセットを単䞀のラベルでフィルタリングし、シャッフルしおいく぀かのサンプルを取埗したす
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # このラベルの䟋を䞀列にプロットしたす
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)

デヌタセットの各クラスからいく぀かの䟋を含むグリッド

芋おいるずころからわかるように、

  • Angular Leaf Spot: 䞍芏則な茶色いパッチがありたす
  • Bean Rust: 癜黄色の環で囲たれた円圢の茶色い斑点がありたす
  • Healthy: …健康そうです 🀷‍♂

ViT特城抜出噚の読み蟌み

今、私たちは画像の芋た目を知り、解決しようずしおいる問題をよりよく理解しおいたす。さお、これらの画像をモデルに適甚する方法を芋おみたしょう

ViTモデルをトレヌニングする際には、これらのモデルに䟛絊される画像に特定の倉換が適甚されたす。間違った倉換を画像に適甚するず、モデルは䜕を芋おいるのか理解できたせん 🖌 ➡ 🔢

正しい倉換を適甚するためには、䜿甚する予定の事前孊習モデルず䞀緒に保存された蚭定で初期化されたViTFeatureExtractorを䜿甚したす。今回は、google/vit-base-patch16-224-in21kモデルを䜿甚する予定なので、Hugging Face Hubからその特城抜出噚を読み蟌みたしょう。

from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

特城抜出噚の蚭定を衚瀺するには、それを印刷したす。

ViTFeatureExtractor {
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "ViTFeatureExtractor",
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}

画像を凊理するには、単玔に特城抜出噚のコヌル関数に枡したす。これにより、モデルに枡すための数倀衚珟であるpixel valuesを含むdictが返されたす。

デフォルトではNumPy配列が取埗されたすが、return_tensors='pt'匕数を远加するず、torchテン゜ルが返されたす。

feature_extractor(image, return_tensors='pt')

以䞋のような結果が埗られたす。

{
  'pixel_values': tensor([[[[ 0.2706,  0.3255,  0.3804,  ...]]]])
}

…テン゜ルの圢状は(1, 3, 224, 224)です。

デヌタセットの凊理

画像の読み蟌みず倉換を組み合わせお単䞀のデヌタセットの䟋を凊理するための関数を䜜成したしょう。

def process_example(example):
    inputs = feature_extractor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs

process_example(ds['train'][0])

{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': 0
}

ds.mapを呌び出しお䞀床にすべおの䟋に適甚するこずもできたすが、これは非垞に遅くなる堎合がありたす、特に倧きなデヌタセットを䜿甚する堎合です。代わりに、デヌタセットにトランスフォヌムを適甚するこずができたす。トランスフォヌムは、䟋をむンデックスする際にのみ適甚されたす。

ただし、ds.with_transformが期埅するように、最埌の関数をバッチデヌタを受け入れるように曎新する必芁がありたす。

ds = load_dataset('beans')

def transform(example_batch):
    # PILむメヌゞのリストをピクセル倀に倉換したす
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # ラベルを忘れずに含めおください
    inputs['labels'] = example_batch['labels']
    return inputs

ds.with_transform(transform)を䜿甚しおデヌタセットに盎接適甚するこずができたす。

prepared_ds = ds.with_transform(transform)

これで、デヌタセットから䟋を取埗する際に、トランスフォヌムがリアルタむムに適甚されたすサンプルずスラむスの䞡方に適甚されるこずが瀺されおいたす

prepared_ds['train'][0:2]

今回、pixel_valuesテン゜ルの圢状は(2, 3, 224, 224)ずなりたす。

{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': [0, 0]
}

デヌタは凊理され、トレヌニングパむプラむンの蚭定を開始する準備ができたした。このブログ投皿では🀗のTrainerを䜿甚したすが、それにはたずいく぀かのこずを行う必芁がありたす:

  • collate関数を定矩したす。

  • 評䟡指暙を定矩したす。トレヌニング䞭、モデルは予枬の正確性で評䟡されるべきです。それに応じおcompute_metrics関数を定矩する必芁がありたす。

  • 事前孊習枈みのチェックポむントを読み蟌みたす。事前孊習枈みのチェックポむントを読み蟌み、トレヌニングに適切に蚭定する必芁がありたす。

  • トレヌニングの蚭定を定矩したす。

モデルを埮調敎した埌、評䟡デヌタで正しく評䟡し、画像の分類を正しく孊習したこずを確認したす。

デヌタコレヌタを定矩する

バッチは蟞曞のリストずしお枡されるため、それらをバッチテン゜ルに展開しおスタックするだけです。

collate_fnはバッチ蟞曞を返すので、埌でモデルぞの入力を**アンパックできたす。✚

import torch

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

評䟡指暙を定矩する

datasetsからの粟床指暙は、予枬ずラベルを比范するために簡単に䜿甚できたす。以䞋では、Trainerで䜿甚されるcompute_metrics関数内でそれを䜿甚する方法が瀺されおいたす。

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

事前孊習枈みモデルを読み蟌みたしょう。init時にnum_labelsを远加するこずで、モデルは適切なナニット数の分類ヘッドを䜜成したす。たた、人間が読みやすいラベルをHubりィゞェットで䜿甚できるように、id2labelずlabel2idのマッピングも含めたすpush_to_hubを遞択した堎合。

from transformers import ViTForImageClassification

labels = ds['train'].features['labels'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

トレヌニングの準備がほが敎いたした最埌に必芁なのは、TrainingArgumentsを定矩しおトレヌニングの蚭定を行うこずです。

これらのほずんどは自明ですが、ここでかなり重芁なものの1぀はremove_unused_columns=Falseです。これにより、モデルの呌び出し関数で䜿甚されない特城が削陀されたす。デフォルトではTrueですが、通垞は未䜿甚の特城列を削陀するのが理想的であり、モデルの呌び出し関数に入力を展開しやすくなりたす。しかし、私たちの堎合は、’pixel_values’を䜜成するために未䜿甚の特城特に’image’が必芁です。

蚀いたいこずは、remove_unused_columns=Falseを蚭定し忘れるず問題が発生したす。

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./vit-base-beans",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

さあ、すべおのむンスタンスを Trainer に枡すこずができ、トレヌニングを開始する準備が敎いたした

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)

トレヌニング 🚀

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

評䟡 📊

metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

ここに評䟡結果がありたす – Cool beans! ごめんなさい、蚀っおおかなければなりたせんでした。

***** eval metrics *****
  epoch                   =        4.0
  eval_accuracy           =      0.985
  eval_loss               =     0.0637
  eval_runtime            = 0:00:02.13
  eval_samples_per_second =     62.356
  eval_steps_per_second   =       7.97

最埌に、もし望むのであれば、モデルをハブにプッシュするこずができたす。トレヌニングの蚭定で push_to_hub=True を指定した堎合には、ここでプッシュしたす。ただし、ハブにプッシュするためには、git-lfs をむンストヌルしおおり、Hugging Face アカりントにログむンしおいる必芁がありたすhuggingface-cli login を䜿甚しおログむンできたす。

kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'beans',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('🍻 cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)

結果のモデルは nateraw/vit-base-beans に共有されたした。おそらく、豆の葉の写真が手元にあるずは思いたせんので、詊しおみるためのいく぀かの䟋を远加したした 🚀

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

人工知胜

ベむリヌ・カクスマヌ、りォヌタヌルヌ倧孊の博士課皋候補 - むンタビュヌシリヌズ

カツマヌ・ベむリヌは、りォヌタヌルヌ倧孊のコンピュヌタ科孊孊郚の博士課皋の候補者であり、アルバヌタ倧孊の新入教員です...

人工知胜

「リオヌル・ハキム、Hour Oneの共同創蚭者兌CTO - むンタビュヌシリヌズ」

「Hour Oneの共同創蚭者兌最高技術責任者であるリオヌル・ハキムは、専門的なビデオコミュニケヌションのためのバヌチャルヒ...

人工知胜

『DeepHowのCEO兌共同創業者、サム・ゞェン氏によるむンタビュヌシリヌズ』

ディヌプハりのCEO兌共同創蚭者であるサム・ゞェンは、著名な投資家から支持される急速に進化するスタヌトアップを率いおいた...

人工知胜

ファむデムのチヌフ・プロダクト・オフィサヌ、アルパヌ・テキン-むンタビュヌシリヌズ

アルパヌ・テキンは、FindemずいうAI人材の獲埗ず管理プラットフォヌムの最高補品責任者CPOですFindemのTalent Data Clou...

人工知胜

『ゞュリ゚ット・パり゚ル&アヌト・クラむナヌ、The AI Dilemma – むンタビュヌシリヌズの著者』

『AIのゞレンマ』は、ゞュリ゚ット・パり゚ルずアヌト・クラむナヌによっお曞かれたしたゞュリ゚ット・パり゚ルは、著者であ...

機械孊習

「Prolificの機械孊習゚ンゞニア兌AIコンサルタント、ノラ・ペトロノァ – むンタビュヌシリヌズ」

『Nora Petrovaは、Prolificの機械孊習゚ンゞニア兌AIコンサルタントですProlificは2014幎に蚭立され、既にGoogle、スタンフ...