Skip to main content
wandb.watch()는 PyTorch 모델의 파라미터와 그라디언트에 훅을 연결하고, 일정한 간격으로 해당 값의 히스토그램을 로깅합니다. 이는 트레이닝 불안정성, 그라디언트 소실, 그리고 죽은 뉴런을 진단하는 데 유용합니다. 기본 사용법 첫 번째 트레이닝 step 전에, wandb.init()wandb.watch()를 호출하세요:
import wandb
import torch.nn as nn

wandb.init(project="my-project")

model = MyModel()
wandb.watch(model, log="gradients", log_freq=100)

for step, batch in enumerate(dataloader):
    loss = train_step(model, batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    wandb.log({"train/loss": loss.item()}, step=step)

wandb.finish()
그라디언트 히스토그램은 log_freq 배치마다 로깅됩니다(Run.watch()의 기본값은 log_freq=1000이며, 예시에서는 더 빠른 피드백을 위해 100을 사용합니다). 히스토그램은 gradients/layer_name.weight와 같은 키로 Charts 탭에 표시됩니다. log 매개변수 옵션
Value로깅되는 항목
"gradients"그라디언트 히스토그램만(기본값)
"parameters"가중치/매개변수 히스토그램만
"all"그라디언트와 매개변수 모두
None둘 다 로깅하지 않음 — 모델 그래프 토폴로지만 로깅
wandb.watch(model, log="all", log_freq=50)
모델 그래프 토폴로지 로깅 히스토그램 로깅이 꺼져 있거나 최소한으로만 수행되는 경우에도 계산 그래프를 보려면 log_graph=True를 전달하세요. 그래프는 run의 Overview 탭에 있는 Model 아래에서 확인할 수 있습니다. log, log_graph, log_freq가 어떻게 상호작용하는지는 Run.watch()에서 확인하세요.
wandb.watch(model, log=None, log_graph=True)  # 그래프 중심, 히스토그램 없음
여러 모델 모니터링하기 각 모델마다 wandb.watch()를 별도로 호출하세요(GAN 트레이닝에 유용합니다):
wandb.watch(generator, log="gradients", log_freq=100)
wandb.watch(discriminator, log="gradients", log_freq=100)
각 모델의 그라디언트는 해당 파라미터 이름을 접두사로 붙여 로깅됩니다. 성능 고려 사항 그라디언트 로깅은 log_freq에 비례하는 오버헤드를 추가합니다. 모든 step마다 로깅하면(log_freq=1) 트레이닝 속도가 크게 느려질 수 있습니다. 대부분의 트레이닝 작업에서는 50~200 사이의 값을 사용합니다. 성능이 중요하다면 log="gradients" 대신 log="parameters"로 설정하세요 — 파라미터 히스토그램은 역전파 훅 없이 계산되므로 비용이 더 적게 듭니다. watch 중지하기 트레이닝 도중 그라디언트 로깅을 중지하려면:
wandb.unwatch(model)
이렇게 하면 run을 종료하지 않고 훅만 제거하므로 metric 로깅은 영향 없이 계속됩니다.
Experiments 메트릭 Runs