메인 콘텐츠로 건너뛰기
Try in Colab PyTorch Lightning을 사용하여 이미지 분류 파이프라인을 구축해 보겠습니다. 코드의 가독성과 재현성을 높이기 위해 이 스타일 가이드를 따를 것입니다. 이에 대한 멋진 설명은 여기에서 확인하실 수 있습니다.

PyTorch Lightning 및 W&B 설정

이 튜토리얼을 위해서는 PyTorch Lightning과 W&B가 필요합니다.
pip install lightning -q
pip install wandb -qU
import lightning.pytorch as pl

# 즐겨 사용하는 기계학습 트래킹 툴
from lightning.pytorch.loggers import WandbLogger

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader

from torchmetrics import Accuracy

from torchvision import transforms
from torchvision.datasets import CIFAR10

import wandb
이제 wandb 계정에 로그인해야 합니다.
wandb.login()

DataModule - 우리가 원하던 데이터 파이프라인

DataModule은 데이터 관련 훅을 LightningModule에서 분리하는 방법으로, 데이터셋에 구애받지 않는 모델을 개발할 수 있게 해줍니다. 데이터 파이프라인을 공유 가능하고 재사용 가능한 하나의 클래스로 구성합니다. DataModule은 PyTorch의 데이터 처리와 관련된 5단계를 캡슐화합니다:
  • 다운로드 / 토큰화 / 처리하다.
  • 정제 및 (필요시) 디스크에 저장.
  • 데이터셋 내부에 로드.
  • 변환 적용 (회전, 토큰화 등…).
  • DataLoader 내부에 래핑.
DataModule에 대해 자세히 알아보려면 여기를 참조하세요. CIFAR-10 데이터셋을 위한 DataModule을 만들어 보겠습니다.
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        self.num_classes = 10
    
    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)
    
    def setup(self, stage=None):
        # 데이터로더에서 사용할 트레이닝/검증 데이터셋 할당
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # 데이터로더에서 사용할 테스트 데이터셋 할당
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

콜백 (Callbacks)

콜백은 프로젝트 간에 재사용할 수 있는 독립적인 프로그램입니다. PyTorch Lightning은 자주 사용되는 몇 가지 내장 콜백을 제공합니다. PyTorch Lightning의 콜백에 대해 자세히 알아보려면 여기를 참조하세요.

내장 콜백 (Built-in Callbacks)

이 튜토리얼에서는 Early StoppingModel Checkpoint 내장 콜백을 사용합니다. 이들은 Trainer에 전달될 수 있습니다.

커스텀 콜백 (Custom Callbacks)

Keras의 커스텀 콜백에 익숙하다면, PyTorch 파이프라인에서도 동일한 기능을 수행할 수 있다는 점이 매우 반가울 것입니다. 이미지 분류를 수행하고 있으므로, 일부 이미지 샘플에 대한 모델의 예측값을 시각화하는 기능이 도움이 될 수 있습니다. 이를 콜백 형태로 구현하면 초기 단계에서 모델을 디버깅하는 데 도움이 됩니다.
class ImagePredictionLogger(pl.callbacks.Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples
    
    def on_validation_epoch_end(self, trainer, pl_module):
        # 텐서를 CPU로 이동
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        # 모델 예측값 가져오기
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)
        # 이미지를 wandb Image로 로그 기록
        trainer.logger.experiment.log({
            "examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") 
                           for x, pred, y in zip(val_imgs[:self.num_samples], 
                                                 preds[:self.num_samples], 
                                                 val_labels[:self.num_samples])]
            })
        

LightningModule - 시스템 정의

LightningModule은 모델이 아닌 시스템을 정의합니다. 여기서 시스템은 모든 연구 코드를 하나의 클래스로 그룹화하여 독립적으로 만드는 것을 의미합니다. LightningModule은 PyTorch 코드를 5개의 섹션으로 정리합니다:
  • 계산 (__init__).
  • 트레이닝 루프 (training_step)
  • 검증 루프 (validation_step)
  • 테스트 루프 (test_step)
  • 옵티마이저 (configure_optimizers)
이를 통해 쉽게 공유할 수 있는 데이터셋 독립적인 모델을 구축할 수 있습니다. CIFAR-10 분류를 위한 시스템을 만들어 보겠습니다.
class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()
        
        # 하이퍼파라미터 로그 기록
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = torch.nn.MaxPool2d(2)
        self.pool2 = torch.nn.MaxPool2d(2)
        
        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    # conv 블록에서 Linear 레이어로 들어가는 출력 텐서의 크기를 반환합니다.
    def _get_conv_output(self, shape):
        batch_size = 1
        input = torch.autograd.Variable(torch.rand(batch_size, *shape))

        output_feat = self._forward_features(input) 
        n_size = output_feat.data.view(batch_size, -1).size(1)
        return n_size
        
    # conv 블록에서 특징 텐서를 반환합니다.
    def _forward_features(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x
    
    # 추론 시에 사용됩니다.
    def forward(self, x):
       x = self._forward_features(x)
       x = x.view(x.size(0), -1)
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = F.log_softmax(self.fc3(x), dim=1)
       
       return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # 트레이닝 메트릭
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # 검증 메트릭
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # 검증 메트릭
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

트레이닝 및 평가

이제 DataModule을 사용하여 데이터 파이프라인을 구성하고 LightningModule을 사용하여 모델 아키텍처 및 트레이닝 루프를 구성했으므로, PyTorch Lightning의 Trainer가 나머지 모든 것을 자동화해 줍니다. Trainer가 자동화하는 항목:
  • 에포크 및 배치 반복
  • optimizer.step(), backward, zero_grad() 호출
  • .eval() 호출, 그레이디언트 활성화/비활성화
  • 가중치 저장 및 로드
  • W&B 로깅
  • Multi-GPU 트레이닝 지원
  • TPU 지원
  • 16-bit 트레이닝 지원
dm = CIFAR10DataModule(batch_size=32)
# 데이터로더에 엑세스하려면 prepare_data와 setup을 호출해야 합니다.
dm.prepare_data()
dm.setup()

# 커스텀 ImagePredictionLogger 콜백에서 이미지 예측을 기록하는 데 필요한 샘플입니다.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
model = LitModel((3, 32, 32), dm.num_classes)

# wandb logger 초기화
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# 콜백 초기화
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()

# trainer 초기화
trainer = pl.Trainer(max_epochs=2,
                     logger=wandb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples),
                                checkpoint_callback],
                     )

# 모델 트레이닝 
trainer.fit(model, dm)

# 별도의 테스트 세트에서 모델 평가 ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())

# wandb run 종료
run.finish()

마무리 소감

저는 TensorFlow/Keras 에코시스템 출신이라 PyTorch가 우아한 프레임워크임에도 불구하고 다소 어렵게 느껴졌습니다. 개인적인 경험일 뿐이지만요. PyTorch Lightning을 살펴보면서, 제가 PyTorch를 멀리하게 했던 거의 모든 이유가 해결되었다는 것을 깨달았습니다. 제가 느낀 흥분되는 점들을 요약하자면 다음과 같습니다:
  • 이전: 기존의 PyTorch 모델 정의는 여기저기 흩어져 있곤 했습니다. 모델은 특정 model.py 스크립트에 있고 트레이닝 루프는 train.py 파일에 있는 식이죠. 파이프라인을 이해하기 위해 앞뒤로 계속 확인해야 했습니다.
  • 현재: LightningModule은 모델이 training_step, validation_step 등과 함께 정의되는 시스템 역할을 합니다. 이제 모듈화되어 있고 공유가 가능합니다.
  • 이전: TensorFlow/Keras의 가장 좋은 점은 입력 데이터 파이프라인입니다. 데이터셋 카탈로그가 풍부하고 계속 늘어나고 있죠. PyTorch의 데이터 파이프라인은 가장 큰 고충이었습니다. 일반적인 PyTorch 코드에서는 데이터 다운로드/정제/준비 과정이 보통 여러 파일에 흩어져 있습니다.
  • 현재: DataModule은 데이터 파이프라인을 공유 가능하고 재사용 가능한 하나의 클래스로 정리합니다. 이는 단순히 train_dataloader, val_dataloader(s), test_dataloader(s)와 그에 맞는 변환 및 필요한 데이터 처리/다운로드 단계의 집합입니다.
  • 이전: Keras에서는 model.fit으로 모델을 트레이닝하고 model.predict로 추론을 실행할 수 있습니다. model.evaluate는 테스트 데이터에 대한 간편한 평가를 제공했죠. PyTorch는 그렇지 않았습니다. 보통 별도의 train.pytest.py 파일을 보게 됩니다.
  • 현재: LightningModule이 있으면 Trainer가 모든 것을 자동화합니다. 모델을 트레이닝하고 평가하기 위해 trainer.fittrainer.test를 호출하기만 하면 됩니다.
  • 이전: TensorFlow는 TPU를 사랑하고, PyTorch는…
  • 현재: PyTorch Lightning을 사용하면 동일한 모델을 여러 GPU나 심지어 TPU에서도 매우 쉽게 트레이닝할 수 있습니다.
  • 이전: 저는 콜백을 매우 좋아하고 커스텀 콜백 작성을 선호합니다. Early Stopping처럼 사소한 것조차 기존 PyTorch에서는 논의의 대상이 되곤 했습니다.
  • 현재: PyTorch Lightning에서는 Early Stopping과 Model Checkpointing을 사용하는 것이 매우 쉽습니다. 심지어 커스텀 콜백도 작성할 수 있습니다.

🎨 결론 및 리소스

이 리포트가 도움이 되었기를 바랍니다. 코드를 직접 실행해보고 원하는 데이터셋으로 이미지 분류기를 트레이닝해 보시길 권장합니다. PyTorch Lightning에 대해 더 자세히 알아볼 수 있는 리소스는 다음과 같습니다:
  • 단계별 가이드: 공식 튜토리얼 중 하나입니다. 문서가 매우 잘 작성되어 있어 훌륭한 학습 리소스로 적극 추천합니다.
  • W&B와 함께 PyTorch Lightning 사용하기: PyTorch Lightning과 함께 W&B를 사용하는 방법을 더 자세히 배울 수 있는 빠른 colab입니다.