새소식

ML Development/ML Framework - 2022.07.15

Pytorch hub를 통해 MEAL V2 모델 사용하기

  • -

Pytorch hub

pytorch hub는 torch 기반 모델 publish 및 load를 지원하여 사전 학습된 모델을 손쉽게 활용할 수 있도록 도와준다.

현재 load 해서 사용할 수 있는 모델은 48개라고 돼있지만 pytorch-transformers와 같이 여러 모델들을 한 곳에 모아둔 경우도 있어 실제로는 더 다양한 모델을 사용할 수 있다.

 

publish 및 load에 대한 자세한 방법은 아래 링크를 통해서 확인할 수 있다.

 

torch.hub — PyTorch 1.12 documentation

torch.hub Pytorch Hub is a pre-trained model repository designed to facilitate research reproducibility. Publishing models Pytorch Hub supports publishing pre-trained models(model definitions and pre-trained weights) to a github repository by adding a simp

pytorch.org

hub에는 다양한 모델이 있는데 이 중에서 MEAL(Multi-Model Ensemble via Adversarial Learning) V2를 사용해보면서 간단하게 사용법을 익혀보자.

 

MEAL V2 ?

V2라는 호칭에서 알 수 있듯이 기존 MEAL을 개선한 모델이다. 참고로 V1, V2 모두 같은 사람이 만들었다. MEAL에 대한 설명은 포스팅을 통해서 확인할 수 있다.

MEAL V2에선 Knoweldge Distillation과 classification problem의 관계에 대해서 더 고민해보며 MEAL의 구조를 조금 수정했다. 큰 차이점은 다음과 같다.

(1) adopting the similarity loss and discriminator only on the final outputs and (2) using the average of softmax probabilities from all teacher ensembles as the stronger supervision.

V1 - V2 차이점

MEAL V2

파이토치 허브를 사용하면, 이러한 MEAL V2를 쉽게 활용할 수 있다. 링크를 들어간 뒤 적혀있는 코드를 따라서 치기만 하면 간단하게 사용할 수 있다. 만약 귀찮다면 이 포스팅의 남은 부분을 확인하자.   

 

MEAL은 사실 모델이라기보다는 Knowledge Distillation을 이용한 ensemble 방법을 다루고 있다. 따라서 timm을 통해서 활용된 resnet 모델을 불러온다. 

!pip install timm
import torch
# list of models: 'mealv1_resnest50', 'mealv2_resnest50', 'mealv2_resnest50_cutmix', 'mealv2_resnest50_380x380', 'mealv2_mobilenetv3_small_075', 'mealv2_mobilenetv3_small_100', 'mealv2_mobilenet_v3_large_100', 'mealv2_efficientnet_b0'
# load pretrained models, using "mealv2_resnest50_cutmix" as an example
model = torch.hub.load('szq0214/MEAL-V2','meal_v2', 'mealv2_resnest50_cutmix', pretrained=True)
model.eval() # inference만 진행

궁금한 개

예측을 잘하는지 확인하기 위해서 개 사진을 다운로드한다. 사진은 위와 같다.

# Download an example image from the pytorch website
import urllib
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)

 transforms을 통해 기본적인 전처리 후 모델에 입력해 logits과 확률 값을 확인해 본다.

# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 정규화 진행
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # batch로 다루기 위해서 unsqueeze(0) 진행

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)

어떤 클래스에 속하는지 알 수 없으니 ImageNet 클래스 정보를 받아온다.

# Download ImageNet labels
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

각 클래스에 속할 확률 상위 값을 찍어보면 다음과 같다. 모델은 사모예드일 가능성이 가장 높다고 예측했다.

다른 사모예드들과 비슷하게 생긴 걸 보니 잘 예측한 거 같다.

# Read the categories
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]
# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())
Samoyed 0.7010934352874756
Pomeranian 0.08209742605686188
white wolf 0.0401899553835392
Arctic fox 0.013811168260872364
Eskimo dog 0.011597750708460808

궁금한 개
사모 예드


참고

 

PyTorch

An open source machine learning framework that accelerates the path from research prototyping to production deployment.

pytorch.org

MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.