Semi-supervised and semi-weakly supervised ImageNet Models

ResNet and ResNeXt models proposed in "Billion scale semi-supervised learning for image classification".

Author: Facebook AI
Repository: facebookresearch/semi-supervised-ImageNet1K-models

import torch
# === Semi-weakly supervised models pretrained with 940 million hashtagged images ===
model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet18_swsl')
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet50_swsl')
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext50_32x4d_swsl')
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x4d_swsl')
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x8d_swsl')
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x16d_swsl')
# ================= Semi-supervised models pretrained with YFCC100M data ==================
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet18_ssl')
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet50_ssl')
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext50_32x4d_ssl')
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x4d_ssl')
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x8d_ssl')
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x16d_ssl')
model.eval()
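
The entry point names above follow an `<architecture>_<suffix>` pattern, where the suffix is `swsl` (semi-weakly supervised) or `ssl` (semi-supervised). The sketch below builds and validates such a name before it is passed to `torch.hub.load`. The helper `entrypoint_name` is hypothetical and not part of the repository; only the architecture names and suffixes come from the calls above.

```python
# Architectures and suffixes as they appear in the torch.hub.load calls above.
ARCHITECTURES = (
    "resnet18", "resnet50", "resnext50_32x4d",
    "resnext101_32x4d", "resnext101_32x8d", "resnext101_32x16d",
)
SUFFIXES = {"semi-weakly supervised": "swsl", "semi-supervised": "ssl"}

def entrypoint_name(architecture: str, supervision: str) -> str:
    """Return the hub entry point name (hypothetical helper, for illustration)."""
    if architecture not in ARCHITECTURES:
        raise ValueError(f"unknown architecture: {architecture!r}")
    if supervision not in SUFFIXES:
        raise ValueError(f"unknown supervision mode: {supervision!r}")
    return f"{architecture}_{SUFFIXES[supervision]}"

name = entrypoint_name("resnext101_32x8d", "semi-weakly supervised")
print(name)
# The name could then be used as:
# model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', name)
```

Validating the name locally avoids a network round trip just to discover a typo.
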
All pretrained models expect input images normalized in the same way: mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are at least 224. Load the images into the range [0, 1], then normalize them using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
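
As a worked example of this normalization, the sketch below maps a single 8-bit RGB pixel to the values the model would see: scale to [0, 1], then subtract the per-channel mean and divide by the per-channel standard deviation. The function name `normalize_pixel` and the sample pixel are illustrative assumptions; in practice `transforms.ToTensor` and `transforms.Normalize` perform these steps on whole images.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    """Scale an 8-bit (R, G, B) pixel to [0, 1], then standardize each channel."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, MEAN, STD))

# A pure white pixel: each channel becomes (1.0 - mean) / std,
# roughly (2.25, 2.43, 2.64).
white = normalize_pixel((255, 255, 255))
print(white)
```
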
Here is a sample execution.
# Download an example image from the PyTorch website.
import urllib.request
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)
# Sample execution (requires torchvision).
from PIL import Image
from torchvision import transforms
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # create the mini-batch the model expects
# Move the input and the model to the GPU for speed, if one is available.
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')
with torch.no_grad():
    output = model(input_batch)
# Tensor of size 1000 holding confidence scores for the 1000 ImageNet classes.
print(output[0])
# The output holds unnormalized scores. Apply softmax to get probabilities.
print(torch.nn.functional.softmax(output[0], dim=0))
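
To make the softmax step concrete, here is a minimal pure-Python sketch that turns unnormalized scores into probabilities and picks the highest-scoring classes. The four scores are made up for illustration, and `softmax` and `top_k` are hypothetical helpers; on a real output tensor, `torch.nn.functional.softmax` followed by `torch.topk` would serve the same purpose.

```python
import math

def softmax(scores):
    """Convert unnormalized scores to probabilities that sum to 1."""
    peak = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probabilities, k):
    """Return (class_index, probability) pairs for the k largest probabilities."""
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

scores = [2.0, 0.5, -1.0, 3.0]  # made-up scores for a 4-class toy problem
probabilities = softmax(scores)
for index, probability in top_k(probabilities, 2):
    print(index, round(probability, 3))
```
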
This page covers the semi-supervised and semi-weakly supervised ImageNet classification models proposed in Billion-scale Semi-Supervised Learning for Image Classification.
In the "semi-supervised" setting, a high-capacity teacher model is trained on the ImageNet1K training data. The student model is pretrained on a subset of the unlabeled YFCC100M images and then fine-tuned on the ImageNet1K training data. See the paper mentioned above for details.
In the "semi-weakly supervised" setting, the teacher model is pretrained on a subset of 940 million hashtagged images and then fine-tuned on the ImageNet1K training data. About 1,500 hashtags are used, chosen to match the synonym sets (synsets) of the ImageNet1K labels. The hashtags serve as labels only while pretraining the teacher model. The student model is pretrained on images from the teacher's dataset with ImageNet1K labels and then fine-tuned on the ImageNet1K training data.
Compared with Xie et al., Mixup, LabelRefinery, AutoAugment, and weakly supervised training, the semi-supervised and semi-weakly supervised approaches substantially improve the ImageNet top-1 validation accuracy of ResNet and ResNeXt models. For example, the ResNet-50 architecture reaches 81.2% top-1 accuracy on the ImageNet validation set.
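
The teacher/student setup described above can be illustrated with a small sketch. It covers only the step where a trained teacher picks unlabeled images for student pretraining, and it assumes, purely for illustration, that the teacher's per-class scores are available as plain lists and that the K highest-scoring images are kept for each class. The function name, the toy scores, and the value of K are all hypothetical; the paper remains the reference for the actual procedure.

```python
def select_pretraining_set(teacher_scores, k):
    """Pick, for each class, the k unlabeled images the teacher scores highest.

    teacher_scores maps image id -> list of per-class scores from the teacher.
    Returns a dict mapping class index -> list of selected image ids.
    """
    num_classes = len(next(iter(teacher_scores.values())))
    selected = {}
    for cls in range(num_classes):
        ranked = sorted(
            teacher_scores,
            key=lambda img: teacher_scores[img][cls],
            reverse=True,
        )
        selected[cls] = ranked[:k]
    return selected

# Toy teacher outputs for four unlabeled images and two classes (made-up numbers).
teacher_scores = {
    "img_a": [0.9, 0.1],
    "img_b": [0.2, 0.8],
    "img_c": [0.6, 0.4],
    "img_d": [0.3, 0.7],
}
pseudo_labeled = select_pretraining_set(teacher_scores, k=2)
print(pseudo_labeled)
# The student would be pretrained on these (image, class) pairs and then
# fine-tuned on the labeled ImageNet1K training data.
```
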
Architecture | Supervision | #Parameters | FLOPS | Top-1 Acc. | Top-5 Acc. |
---|---|---|---|---|---|
ResNet-18 | semi-supervised | 14M | 2B | 72.8 | 91.5 |
ResNet-50 | semi-supervised | 25M | 4B | 79.3 | 94.9 |
ResNeXt-50 32x4d | semi-supervised | 25M | 4B | 80.3 | 95.4 |
ResNeXt-101 32x4d | semi-supervised | 42M | 8B | 81.0 | 95.7 |
ResNeXt-101 32x8d | semi-supervised | 88M | 16B | 81.7 | 96.1 |
ResNeXt-101 32x16d | semi-supervised | 193M | 36B | 81.9 | 96.2 |
ResNet-18 | semi-weakly supervised | 14M | 2B | 73.4 | 91.9 |
ResNet-50 | semi-weakly supervised | 25M | 4B | 81.2 | 96.0 |
ResNeXt-50 32x4d | semi-weakly supervised | 25M | 4B | 82.2 | 96.3 |
ResNeXt-101 32x4d | semi-weakly supervised | 42M | 8B | 83.4 | 96.8 |
ResNeXt-101 32x8d | semi-weakly supervised | 88M | 16B | 84.3 | 97.2 |
ResNeXt-101 32x16d | semi-weakly supervised | 193M | 36B | 84.8 | 97.4 |
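
For a quick reading of the table, the sketch below computes the top-1 gain of each semi-weakly supervised model over its semi-supervised counterpart. The accuracies are copied from the table above; the variable names are illustrative.

```python
# Top-1 accuracy (%) from the table: (semi-supervised, semi-weakly supervised).
top1 = {
    "ResNet-18": (72.8, 73.4),
    "ResNet-50": (79.3, 81.2),
    "ResNeXt-50 32x4d": (80.3, 82.2),
    "ResNeXt-101 32x4d": (81.0, 83.4),
    "ResNeXt-101 32x8d": (81.7, 84.3),
    "ResNeXt-101 32x16d": (81.9, 84.8),
}

# Gain in percentage points, rounded to the table's one-decimal precision.
gains = {
    arch: round(semi_weak - semi, 1)
    for arch, (semi, semi_weak) in top1.items()
}
for arch, gain in gains.items():
    print(f"{arch}: +{gain} points")
```

In this table the gain grows with model size, from 0.6 points for ResNet-18 to 2.9 points for ResNeXt-101 32x16d.
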
If you use the models released in the repository, please cite the following paper (Billion-scale Semi-Supervised Learning for Image Classification).
@misc{yalniz2019billionscale,
title={Billion-scale semi-supervised learning for image classification},
author={I. Zeki Yalniz and Hervé Jégou and Kan Chen and Manohar Paluri and Dhruv Mahajan},
year={2019},
eprint={1905.00546},
archivePrefix={arXiv},
primaryClass={cs.CV}
}