Fine-Tune ViT for Image Classification with ð€ Transformers
Just as transformer-based models have revolutionized NLP, we're now seeing an explosion of papers applying them to all sorts of other domains. One of the most revolutionary of these was the Vision Transformer (ViT), which was introduced in June 2021 by a team of researchers at Google Brain.
This paper explored how you can tokenize images, just as you would tokenize sentences, so that they can be passed to transformer models for training. It's a pretty simple concept, really…
- Split an image into a grid of sub-image patches
- Embed each patch with a linear projection
- Each embedded patch becomes a token, and the resulting sequence of embedded patches is the sequence you pass to the model.
It turns out that once you've done the above, you can pre-train and fine-tune transformers just as you're used to with NLP tasks. Pretty sweet ð.
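To make the patch-tokenization idea above concrete, here is a minimal, hedged sketch in PyTorch. It is purely illustrative: the 16-pixel patch size and 768-dimensional embedding are assumptions that mirror ViT-Base, not the actual ð€ implementation.

import torch
import torch.nn as nn

image = torch.randn(1, 3, 224, 224)   # dummy RGB image tensor (batch of 1)
patch_size = 16                        # 224 / 16 = 14 patches per side -> 196 patches

# 1. Split the image into a grid of 16x16 patches and flatten each patch
patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.reshape(1, 3, -1, patch_size, patch_size)   # (1, 3, 196, 16, 16)
patches = patches.permute(0, 2, 1, 3, 4).flatten(2)           # (1, 196, 768)

# 2. Embed each flattened patch with a single linear projection
embed = nn.Linear(3 * patch_size * patch_size, 768)
tokens = embed(patches)                                        # (1, 196, 768)

# 3. This sequence of embedded patches is what gets fed to the transformer
print(tokens.shape)  # torch.Size([1, 196, 768])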
In this blog post, we'll walk through how to leverage ð€ datasets to download and process image classification datasets, and then use them to fine-tune a pre-trained ViT with ð€ transformers.
To get started, let's first install both of those packages.
pip install datasets transformers
Load a dataset
Let's start by loading a small image classification dataset and taking a look at its structure.
We'll use the beans dataset, which is a collection of pictures of healthy and unhealthy bean leaves. ð
from datasets import load_dataset
ds = load_dataset('beans')
ds
Let's take a look at the 400th example from the 'train' split of the beans dataset. You'll notice each example in the dataset has 3 features:
- image: a PIL Image
- image_file_path: the str path to the image file that was loaded as image
- labels: a datasets.ClassLabel feature, which is an integer representation of the label (later you'll see how to get the string class names, don't worry!)
ex = ds['train'][400]
ex
{
'image': <PIL.JpegImagePlugin ...>,
'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
'labels': 1
}
Let's take a look at the image ð
image = ex['image']
image
That's definitely a leaf! But what kind? ð
Since the 'labels' feature of this dataset is a datasets.features.ClassLabel, we can use it to look up the corresponding name for this example's label ID.
First, let's access the feature definition for the 'labels'.
labels = ds['train'].features['labels']
labels
ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)
Now, let's print out the class label for our example. You can do that by using the int2str function of ClassLabel, which, as the name implies, lets you pass the integer representation of the class to look up its string label.
labels.int2str(ex['labels'])
'bean_rust'
Turns out the leaf shown above is infected with Bean Rust, a serious disease in bean plants. ð¢
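For reference, you can use the same int2str function to map every label ID in the dataset to its human-readable name (a small illustrative snippet, not part of the original walkthrough):

id2name = {i: labels.int2str(i) for i in range(labels.num_classes)}
print(id2name)  # {0: 'angular_leaf_spot', 1: 'bean_rust', 2: 'healthy'}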
Let's write a function that'll display a grid of examples from each class to get a better sense of what we're working with.
import random
from PIL import ImageDraw, ImageFont, Image
def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

    for label_id, label in enumerate(labels):

        # Filter the dataset by a single label, shuffle it, and grab a few samples
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # Plot this label's examples along a row
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid
show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)
A grid containing a few examples from each class in the dataset
From what I'm seeing:
- Angular Leaf Spot: has irregular brown patches
- Bean Rust: has circular brown spots surrounded with a white-ish yellow ring
- Healthy: …looks healthy ð€·ââïž
Loading the ViT Feature Extractor
Now we know what our images look like and have a better understanding of the problem we're trying to solve. Let's see how we can prepare these images for our model!
When ViT models are trained, specific transformations are applied to the images fed into them. Use the wrong transformations on your image, and the model won't understand what it's seeing! ðŒ â¡ïž ð¢
To make sure we apply the correct transformations, we will use a ViTFeatureExtractor initialized with a configuration that was saved along with the pretrained model we plan to use. In our case, we'll be using the google/vit-base-patch16-224-in21k model, so let's load its feature extractor from the Hugging Face Hub.
from transformers import ViTFeatureExtractor
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
You can see the feature extractor configuration by printing it.
ViTFeatureExtractor {
"do_normalize": true,
"do_resize": true,
"feature_extractor_type": "ViTFeatureExtractor",
"image_mean": [
0.5,
0.5,
0.5
],
"image_std": [
0.5,
0.5,
0.5
],
"resample": 2,
"size": 224
}
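In plain terms, this configuration says the feature extractor will resize each image to 224x224 and normalize every channel with a mean and standard deviation of 0.5. Purely as an illustration of what those fields mean, here is a rough torchvision equivalent (an assumption for clarity, not the exact code ViTFeatureExtractor runs internally):

from torchvision import transforms

# Rough, illustrative equivalent of the config shown above
manual_transform = transforms.Compose([
    transforms.Resize((224, 224)),                 # do_resize / size
    transforms.ToTensor(),                         # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5],     # image_mean
                         std=[0.5, 0.5, 0.5]),     # image_std
])

pixel_values = manual_transform(image).unsqueeze(0)  # shape: (1, 3, 224, 224)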
To process an image, simply pass it to the feature extractor's call function. This returns a dict containing pixel_values, which is the numeric representation of your image that gets passed to the model.
You get a NumPy array back by default, but if you add the return_tensors='pt' argument, you'll get back torch tensors instead.
feature_extractor(image, return_tensors='pt')
Which gives you something like this…
{
'pixel_values': tensor([[[[ 0.2706, 0.3255, 0.3804, ...]]]])
}
…where the shape of the tensor is (1, 3, 224, 224).
Processing the Dataset
Now that we know how to read images and transform them into inputs, let's write a function that puts those two things together to process a single example from the dataset.
def process_example(example):
    inputs = feature_extractor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs
process_example(ds['train'][0])
{
'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078, ..., ]]]]),
'labels': 0
}
While you could call ds.map and apply this to every example at once, this can be very slow, especially if you use a larger dataset. Instead, you can apply a transform to the dataset. Transforms are only applied to examples as you index them.
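For comparison, the eager approach would look something like the sketch below (a hypothetical illustration only; we won't use it here, since it processes and caches every example up front rather than on the fly):

# Eager alternative (illustrative): process and cache every example immediately.
# Fine for small datasets, but slow and storage-hungry for larger ones.
eager_ds = ds.map(process_example)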
First, though, you'll need to update the last function to accept a batch of data, since that's what ds.with_transform expects.
ds = load_dataset('beans')
def transform(example_batch):
    # Take a list of PIL images and turn them into pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs
You can apply this directly to the dataset using ds.with_transform(transform).
prepared_ds = ds.with_transform(transform)
Now, whenever you get an example from the dataset, the transform will be applied in real time (on both samples and slices, as shown below).
prepared_ds['train'][0:2]
This time, the resulting pixel_values tensor will have shape (2, 3, 224, 224).
{
'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078, ..., ]]]]),
'labels': [0, 0]
}
The data is processed and we are ready to start setting up the training pipeline. This blog post uses ð€'s Trainer, but that'll require us to do a few things first:
- Define a collate function.
- Define an evaluation metric. During training, the model should be evaluated on its prediction accuracy. You should define a compute_metrics function accordingly.
- Load a pretrained checkpoint. You need to load a pretrained checkpoint and configure it correctly for training.
- Define the training configuration.
After fine-tuning the model, we will correctly evaluate it on the evaluation data and verify that it has indeed learned to correctly classify the images.
Define our data collator
Batches are coming in as lists of dicts, so you can just unpack and stack those into batch tensors.
Because the collate_fn will return a batch dict, you can **unpack the inputs to the model later. âš
import torch
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }
Define an evaluation metric
The accuracy metric from ð€ datasets can easily be used to compare the predictions with the labels. Below, you can see how to use it within a compute_metrics function that will be used by the Trainer.
import numpy as np
from datasets import load_metric
metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
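Note that load_metric has since been deprecated in recent versions of ð€ datasets in favor of the standalone ð€ evaluate library. If you are on a newer version, an equivalent sketch would look like this (assuming evaluate is installed):

import numpy as np
import evaluate

# Same accuracy metric, loaded through the evaluate library instead of datasets
metric = evaluate.load("accuracy")

def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)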
Let's load the pretrained model. We'll add num_labels on init so the model creates a classification head with the right number of units. We'll also include the id2label and label2id mappings so that we have human-readable labels in the Hub widget (if you choose to push_to_hub).
from transformers import ViTForImageClassification
labels = ds['train'].features['labels'].names
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)
We're almost ready to train! The last thing needed before that is to set up the training configuration by defining TrainingArguments.
Most of these are pretty self-explanatory, but one that is quite important here is remove_unused_columns=False. This one will drop any features not used by the model's call function. By default it's True, because usually it's ideal to drop unused feature columns, making it easier to unpack inputs into the model's call function. But in our case, we need the unused features ('image' in particular) in order to create 'pixel_values'.
What I'm trying to say is that you'll have a bad time if you forget to set remove_unused_columns=False.
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="./vit-base-beans",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)
Now, all instances can be passed to Trainer and we are ready to start training!
from transformers import Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)
Train ð
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
Evaluate ð
metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
And here are my evaluation results – Cool beans! Sorry, I had to say it.
***** eval metrics *****
epoch = 4.0
eval_accuracy = 0.985
eval_loss = 0.0637
eval_runtime = 0:00:02.13
eval_samples_per_second = 62.356
eval_steps_per_second = 7.97
Finally, if you want, you can push your model up to the Hub. Here, we'll push it up if you specified push_to_hub=True in the training configuration. Note that in order to push to the Hub, you'll have to have git-lfs installed and be logged into your Hugging Face account (which can be done via huggingface-cli login).
kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'beans',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('ð» cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)
The resulting model has been shared to nateraw/vit-base-beans. I'm assuming you don't have pictures of bean leaves lying around, so I added some examples for you to give it a try! ð
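If you want a quick way to try the fine-tuned checkpoint yourself, a minimal sketch using the image-classification pipeline looks like this (assuming the nateraw/vit-base-beans checkpoint is available on the Hub and you still have the beans dataset loaded):

from transformers import pipeline

# Load the fine-tuned checkpoint from the Hub and classify a single leaf image
classifier = pipeline("image-classification", model="nateraw/vit-base-beans")
print(classifier(ds['validation'][0]['image']))  # a list of {'label': ..., 'score': ...} dicts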