소스코드 원본 링크 : https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

import packages and modules

# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # enable matplotlib interactive mode (switched off again by plt.ioff() at the end of the script)

데이터 로드 및 전처리

# Preprocessing pipelines, one per split.
#   'train': random resized crop + random horizontal flip (augmentation),
#            then tensor conversion and normalization.
#   'val':   resize + center crop (no augmentation), then the same tensor
#            conversion and normalization.
# torchvision is used to load the dataset conveniently; its transforms module
# is used for the preprocessing itself.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
}

# Datasets, loaders and sizes are each stored in a dict keyed by split name
# ('train' / 'val').
data_dir = 'data/hymenoptera_data'
image_datasets = {
    'train': datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                  transform=data_transforms['train']),
    'val': datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                transform=data_transforms['val']),
}
dataloaders = {
    'train': torch.utils.data.DataLoader(image_datasets['train'],
                                         batch_size=4, shuffle=True,
                                         num_workers=4),
    'val': torch.utils.data.DataLoader(image_datasets['val'],
                                       batch_size=4, shuffle=True,
                                       num_workers=4),
}
dataset_sizes = {split: len(dataset)
                 for split, dataset in image_datasets.items()}
# Class names come from the training split.
class_names = image_datasets['train'].classes

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

이미지 시각화

def imshow(inp, title=None):
    """Show a normalized image tensor with matplotlib.

    Args:
        inp: Image tensor laid out as (channel, height, width) and normalized
            with the same mean/std used in ``data_transforms``.
        title: Optional title for the plot; nothing is set when ``None``.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # matplotlib needs a numpy array rather than a tensor, with the channel
    # axis last: (channel, height, width) -> (height, width, channel).
    image = inp.numpy().transpose((1, 2, 0))

    # Undo the normalization applied during preprocessing, then clamp every
    # value into [0, 1] (np.clip limits values to the given interval).
    image = np.clip(channel_std * image + channel_mean, 0, 1)

    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # brief pause so that plots are updated

# Get a batch of training data
# (the first (images, class-index) pair yielded by the training loader)
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
# so the whole batch can be displayed as a single image
out = torchvision.utils.make_grid(inputs)

# Show the grid, titled with the class name of every image in the batch
imshow(out, title=[class_names[x] for x in classes])

모델 학습

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a training pass followed by a validation pass. Data is
    read from the module-level ``dataloaders`` dict, per-split sample counts
    from ``dataset_sizes``, and batches are moved to the module-level
    ``device``. Loss and accuracy are printed for each phase.

    Args:
        model: Network to train; updated in place.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training pass.
        num_epochs: Number of epochs to run.

    Returns:
        The same ``model`` object, with the weights of the epoch that reached
        the highest validation accuracy loaded.
    """
    since = time.time()

    # state_dict() holds the model's weights; keep a deep copy so the
    # snapshot is not affected by later training steps.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics
                # loss is scaled by the batch size so that dividing by the
                # dataset size below gives a per-sample value
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if is_train:
                # Advance the LR schedule only after this epoch's
                # optimizer.step() calls; stepping the scheduler first skips
                # the initial learning-rate value.
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

모델 예측 시각화

def visualize_model(model, num_images=6):
    """Plot predictions for the first ``num_images`` validation images.

    Images are read from ``dataloaders['val']`` and drawn in a two-column
    grid, each titled with the predicted class name. The model is switched to
    eval mode for the duration of the call and its previous train/eval mode
    is restored afterwards, including when an exception is raised.

    Args:
        model: Network used to produce the predictions.
        num_images: Maximum number of images to display; nothing is drawn
            when it is not positive.
    """
    if num_images <= 0:
        return

    was_training = model.training
    # eval() tells every layer that the model is in evaluation mode, so
    # batchnorm and dropout layers behave accordingly.
    model.eval()
    images_so_far = 0
    plt.figure()
    # Round up so an odd num_images still gets enough grid cells.
    num_rows = (num_images + 1) // 2

    try:
        with torch.no_grad():
            for inputs, _ in dataloaders['val']:
                inputs = inputs.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(num_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title(
                        'predicted: {}'.format(class_names[preds[j]]))
                    imshow(inputs.cpu().data[j])

                    if images_so_far == num_images:
                        return
    finally:
        # Put the model back into whichever mode the caller had it in.
        model.train(mode=was_training)

FINETUNING THE CONVNET

# Start from a ResNet-18 loaded with pretrained weights.
model_ft = models.resnet18(pretrained=True)
# in_features is the number of inputs of a linear layer, so fc.in_features is
# the input size of the final fully connected layer.
num_ftrs = model_ft.fc.in_features
# Replace the final layer with one output per dataset class instead of a
# hard-coded count, so the script keeps working if the class folders change.
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

학습 및 평가

# Fine-tune the whole network for 25 epochs; train_model returns the model
# with its best validation weights loaded.
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
# Plot predictions of the fine-tuned model on validation images.
visualize_model(model_ft)

## CONVNET AS FIXED FEATURE EXTRACTOR

# Start from a ResNet-18 loaded with pretrained weights and freeze all of its
# existing parameters so no gradients are computed for them.
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default,
# so only this replacement layer is trainable. It gets one output per dataset
# class instead of a hard-coded count.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, len(class_names))

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

학습 및 평가

# Train the feature-extractor variant (its optimizer only updates the final
# layer) for 25 epochs, then plot its predictions on validation images.
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
visualize_model(model_conv)

# Turn interactive mode back off and show the figures.
plt.ioff()
plt.show()

transfer learning 코드를 실행하기 이전에는 25번의 epoch를 돌리는 데 40분이 넘게 걸렸다. AWS를 이용하여 GPU를 사용한다면 더 빠르게 결과를 얻어낼 수 있겠지만, 이 데이터셋보다 훨씬 많은 양의 데이터를 다뤄야 하고 데이터를 증강(augmentation)까지 한다면 매번 처음부터 다시 학습시키는 것은 굉장히 번거로울 것이다. 미리 학습된 모델과 추가할 데이터를 고려하여 마지막 레이어만 수정하는 것이 바로 transfer learning이다. 기존의 데이터와 추가될 데이터를 고려하여 어떤 레이어를 얼마만큼 수정해야 할지가 결정된다고 한다.