PyTorch 는 특히 연구자들 사이에서 파이썬 기반 딥러닝을 위한 가장 인기 있는 프레임워크 중 하나입니다. W&B 는 그레이디언트 로그 기록부터 CPU 및 GPU 상의 코드 프로파일링에 이르기까지 PyTorch 에 대한 최상급 지원을 제공합니다.
또한 예제 레포지토리에서 Fashion MNIST 데이터셋에 Hyperband를 사용한 하이퍼파라미터 최적화 스크립트와, 이를 통해 생성된 W&B Dashboard를 확인할 수 있습니다.
run.watch를 사용하여 그레이디언트 로그 기록
그레이디언트를 자동으로 로그에 기록하려면, wandb.Run.watch()를 호출하고 PyTorch 모델을 전달하면 됩니다.
import wandb
with wandb.init(config=args) as run:
model = ... # 모델 설정
# 매직 라인
run.watch(model, log_freq=100)
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
run.log({"loss": loss})
동일한 스크립트에서 여러 모델을 추적해야 하는 경우, 각 모델에 대해 wandb.Run.watch()를 별도로 호출할 수 있습니다.
그레이디언트, 메트릭 및 그래프는 순전파(forward) 및 역전파(backward) 패스 이후에 wandb.Run.log()가 호출될 때까지 로그에 기록되지 않습니다.
이미지 및 미디어 로그 기록
이미지 데이터가 포함된 PyTorch Tensors를 wandb.Image에 전달할 수 있으며, torchvision 유틸리티가 이를 자동으로 이미지로 변환하는 데 사용됩니다.
with wandb.init(project="my_project", entity="my_entity") as run:
images_t = ... # PyTorch Tensor로 이미지를 생성하거나 로드
run.log({"examples": [wandb.Image(im) for im in images_t]})
PyTorch 및 기타 프레임워크에서 W&B 에 풍부한 미디어를 로그 기록하는 방법에 대한 자세한 내용은 미디어 로그 기록 가이드를 확인하세요.
미디어와 함께 모델의 예측값(prediction)이나 파생된 메트릭 같은 정보를 포함하고 싶다면 wandb.Table을 사용하세요.
with wandb.init() as run:
my_table = wandb.Table()
my_table.add_column("image", images_t)
my_table.add_column("label", labels)
my_table.add_column("class_prediction", predictions_t)
# W&B에 Table 로그 기록
run.log({"mnist_predictions": my_table})
데이터셋과 모델을 로그 기록하고 시각화하는 방법에 대한 자세한 내용은 W&B Tables 가이드를 확인하세요.
PyTorch 코드 프로파일링
W&B 는 PyTorch Kineto의 Tensorboard 플러그인과 직접 연동되어 PyTorch 코드를 프로파일링하고, CPU 와 GPU 통신의 세부 사항을 검사하며, 병목 현상 및 최적화 지점을 식별하는 툴을 제공합니다.
profile_dir = "path/to/run/tbprofile/"
profiler = torch.profiler.profile(
schedule=schedule, # 스케줄링에 대한 자세한 내용은 profiler 문서를 참조하세요
on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
with_stack=True,
)
with profiler:
... # 프로파일링하려는 코드를 여기서 실행하세요
# 자세한 사용법은 profiler 문서를 참조하세요
# wandb Artifact 생성
profile_art = wandb.Artifact("trace", type="profile")
# pt.trace.json 파일을 Artifact에 추가
profile_art.add_file(glob.glob(profile_dir + ".pt.trace.json"))
# artifact 저장
profile_art.save()
이 Colab에서 작동하는 예제 코드를 확인하고 실행해 보세요.
대화형 트레이스 뷰어 툴은 Chrome 브라우저에 최적화된 Chrome Trace Viewer 를 기반으로 합니다.