---
layout: hub_detail
background-class: hub-background
body-class: hub
title: vgg-nets
summary: Award winning ConvNets from 2014 Imagenet ILSVRC challenge
category: researchers
image: vgg.png
author: Pytorch Team
tags: [vision, scriptable]
github-link:
github-id: pytorch/vision
featured_image_1: vgg.png
featured_image_2: no-image
accelerator: cuda-optional
order: 10
demo-model-link:
---
```python
import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg11', pretrained=True)
# The following variants of the architecture are also available:
# model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg11_bn', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg13', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg13_bn', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16_bn', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg19', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg19_bn', pretrained=True)
model.eval()
```

λͺ¨λ“  사전 ν›ˆλ ¨λœ λͺ¨λΈμ€ ν›ˆλ ¨λ•Œμ™€ 같은 λ°©μ‹μœΌλ‘œ μ •κ·œν™”λœ μž…λ ₯ 이미지λ₯Ό μ£Όμ–΄μ•Όν•©λ‹ˆλ‹€. 즉, (3 x H x W) λͺ¨μ–‘μ˜ 3채널 RGB μ΄λ―Έμ§€μ˜ λ―Έλ‹ˆλ°°μΉ˜μ—μ„œ H와 WλŠ” μ΅œμ†Œ 224κ°€ 될 κ²ƒμœΌλ‘œ μ˜ˆμƒλ©λ‹ˆλ‹€. μ΄λ―Έμ§€λŠ” [0, 1] λ²”μœ„λ‘œ λ‘œλ“œν•œ λ‹€μŒ(RGB μ±„λ„λ§ˆλ‹€ 0~255κ°’μœΌλ‘œ ν‘œν˜„λ˜λ―€λ‘œ 이미지λ₯Ό 255둜 λ‚˜λˆ”) mean = [0.485, 0.456, 0.406]κ³Ό std = [0.229, 0.224, 0.225] 값을 μ‚¬μš©ν•˜μ—¬ μ •κ·œν™”ν•΄μ•Ό ν•©λ‹ˆλ‹€.

λ‹€μŒμ€ μƒ˜ν”Œ μ‹€ν–‰μž…λ‹ˆλ‹€.

# νŒŒμ΄ν† μΉ˜ μ›Ήμ‚¬μ΄νŠΈμ—μ„œ 예제 이미지λ₯Ό λ‹€μš΄λ‘œλ“œ ν•©λ‹ˆλ‹€
import urllib
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)
# μƒ˜ν”Œ μ‹€ν–‰ (torchvision이 ν•„μš”ν•©λ‹ˆλ‹€)
from PIL import Image
from torchvision import transforms
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # λͺ¨λΈμ˜ μž…λ ₯값에 맞좘 λ―Έλ‹ˆ 배치 생성

# κ°€λŠ₯ν•˜λ©΄ 속도λ₯Ό μœ„ν•΄μ„œ μž…λ ₯κ³Ό λͺ¨λΈμ„ GPU둜 이동 ν•©λ‹ˆλ‹€
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)
# Imagenet의 1000개 ν΄λž˜μŠ€μ— λŒ€ν•œ 신뒰도 μ μˆ˜κ°€ μžˆλŠ” 1000개의 Tensorμž…λ‹ˆλ‹€.
print(output[0])
# 좜λ ₯에 μ •κ·œν™”λ˜μ§€ μ•Šμ€ μ μˆ˜κ°€ μžˆμŠ΅λ‹ˆλ‹€. ν™•λ₯ μ„ μ–»μœΌλ €λ©΄ μ†Œν”„νŠΈλ§₯슀λ₯Ό μ‹€ν–‰ν•  수 μžˆμŠ΅λ‹ˆλ‹€.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)
# ImageNet 라벨 λ‹€μš΄λ‘œλ“œ
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
# μΉ΄ν…Œκ³ λ¦¬ 읽기
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]
# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

λͺ¨λΈ μ„€λͺ…

Implementations are provided here for each of the configurations, and their batchnorm versions, of the models proposed in Very Deep Convolutional Networks for Large-Scale Image Recognition.

For example, configuration A presented in the paper is vgg11, configuration B is vgg13, configuration D is vgg16, and configuration E is vgg19. Their batchnorm versions carry the _bn suffix.
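The naming scheme above can be sketched as a plain mapping (a hypothetical helper, not part of torchvision; the numbers are the weight-layer counts that give each configuration its name):

```python
# Hypothetical mapping from the paper's configurations (A, B, D, E) to
# torch.hub model names; the number is the configuration's weight-layer count.
configs = {
    "A": ("vgg11", 11),
    "B": ("vgg13", 13),
    "D": ("vgg16", 16),
    "E": ("vgg19", 19),
}
for cfg, (name, n_layers) in configs.items():
    # each configuration also has a batchnorm variant with the _bn suffix
    print(f"config {cfg} -> {name} ({n_layers} weight layers) / {name}_bn")
```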

사전 ν›ˆλ ¨λœ λͺ¨λΈμ΄ μžˆλŠ” imagenet 데이터 μ„ΈνŠΈμ˜ 1-crop 였λ₯˜μœ¨μ€ μ•„λž˜μ— λ‚˜μ—΄λ˜μ–΄ μžˆμŠ΅λ‹ˆλ‹€.

| Model structure | Top-1 error | Top-5 error |
| --------------- | ----------- | ----------- |
| vgg11           | 30.98       | 11.37       |
| vgg11_bn        | 26.70       | 8.58        |
| vgg13           | 30.07       | 10.75       |
| vgg13_bn        | 28.45       | 9.63        |
| vgg16           | 28.41       | 9.62        |
| vgg16_bn        | 26.63       | 8.50        |
| vgg19           | 27.62       | 9.12        |
| vgg19_bn        | 25.76       | 8.15        |
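As a small sketch (not part of the original example), the table's top-1 column can be put in a dict to pick the variant with the lowest error programmatically:

```python
# Top-1 error rates (%) from the table above.
top1_error = {
    "vgg11": 30.98, "vgg11_bn": 26.70,
    "vgg13": 30.07, "vgg13_bn": 28.45,
    "vgg16": 28.41, "vgg16_bn": 26.63,
    "vgg19": 27.62, "vgg19_bn": 25.76,
}
# The batchnorm variants consistently beat their plain counterparts here.
best = min(top1_error, key=top1_error.get)
print(best, top1_error[best])  # vgg19_bn 25.76
```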

### References