wandb.watch() hooks into a PyTorch model’s parameters and gradients and logs histograms of their values at regular intervals. This is useful for diagnosing training instability, vanishing gradients, and dead neurons.
Basic usage
Call wandb.watch() after wandb.init() and before the first training step:
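import torch
import torch.nn as nn
import wandb

wandb.init(project="my-project")

# Placeholder model, optimizer, loss, and data; substitute your own.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
dataloader = [(torch.randn(16, 10), torch.randn(16, 1)) for _ in range(500)]

# Register the hooks once, after the model exists and before training starts.
wandb.watch(model, log="gradients", log_freq=100)

for step, (inputs, targets) in enumerate(dataloader):
    loss = loss_fn(model(inputs), targets)
    loss.backward()  # gradient hooks fire during the backward pass
    optimizer.step()
    optimizer.zero_grad()

    wandb.log({"train/loss": loss.item()}, step=step)

wandb.finish()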
Gradient histograms are logged every log_freq batches (the Run.watch() default is log_freq=1000; the example uses 100 for faster feedback). They appear in the Charts tab under keys like gradients/layer_name.weight.
log parameter options
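Which training steps produce a snapshot can be sketched in plain Python. This is a minimal model, assuming a snapshot is taken on every log_freq-th backward pass counted from when the watch starts; histogram_steps is a hypothetical helper, not part of the wandb API, and the library's exact counting may differ.

```python
from typing import List


def histogram_steps(total_steps: int, log_freq: int) -> List[int]:
    """Hypothetical helper: 0-based steps at which a histogram snapshot is taken,
    assuming one snapshot on every log_freq-th backward pass."""
    if log_freq < 1:
        raise ValueError("log_freq must be >= 1")
    # Backward pass number n happens at 0-based step n - 1, so snapshots land on
    # steps log_freq - 1, 2 * log_freq - 1, and so on.
    return [s for s in range(total_steps) if (s + 1) % log_freq == 0]


print(histogram_steps(1000, 100))   # ten snapshots
print(histogram_steps(500, 1000))   # run shorter than log_freq
```

Under that assumption, a run shorter than log_freq steps would produce no histograms at all, which is one reason to lower log_freq from the 1000 default for short experiments.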
log value      What is logged
"gradients"    Gradient histograms only (default)
"parameters"   Weight/parameter histograms only
"all"          Both gradient and parameter histograms
None           No histograms (add log_graph=True to log only the model graph)
wandb.watch(model, log="all", log_freq=50)
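To illustrate what a single histogram snapshot summarizes, here is a minimal pure-Python sketch that bins a tensor's values into equal-width buckets. The histogram function and its 64-bin default are assumptions for illustration; W&B computes its histograms internally and may bin differently.

```python
from typing import List, Sequence, Tuple


def histogram(values: Sequence[float], num_bins: int = 64) -> Tuple[List[int], List[float]]:
    """Hypothetical helper: return (counts, edges) for equal-width bins over min..max."""
    if not values:
        raise ValueError("values must be non-empty")
    lo, hi = min(values), max(values)
    if lo == hi:
        # All values identical (for example an all-zero gradient): widen the
        # range so the values still fall inside a bin.
        lo, hi = lo - 0.5, hi + 0.5
    width = (hi - lo) / num_bins
    edges = [lo + i * width for i in range(num_bins + 1)]
    counts = [0] * num_bins
    for v in values:
        # Clamp so that v == hi lands in the last bin instead of overflowing.
        idx = min(int((v - lo) / width), num_bins - 1)
        counts[idx] += 1
    return counts, edges


counts, edges = histogram([0.0, 1.0, 2.0, 3.0], num_bins=4)
print(counts, edges)
```

An all-zero gradient collapses into a single bin, the kind of spike worth looking for when diagnosing dead neurons; bins that crowd ever closer to zero across snapshots can likewise hint at vanishing gradients.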
Logging model graph topology
Pass log_graph=True to log the model's computational graph. To record the graph without any histograms, combine it with log=None. View the graph in the run’s Overview tab under Model. See Run.watch() for how log, log_graph, and log_freq interact.
wandb.watch(model, log=None, log_graph=True)  # graph focus, no histograms
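The interaction between log and log_graph can be summarized as a small lookup. This is a hypothetical helper written for illustration, not part of the wandb API; it assumes log_graph defaults to False, so a call with log=None and no log_graph records nothing.

```python
from typing import Optional, Set

# Accepted values for the log argument, mirroring the table above.
VALID_LOG: Set[Optional[str]] = {"gradients", "parameters", "all", None}


def watched_outputs(log: Optional[str] = "gradients", log_graph: bool = False) -> Set[str]:
    """Hypothetical helper: what a watch() call with these arguments records."""
    if log not in VALID_LOG:
        raise ValueError(f"unsupported log value: {log!r}")
    outputs: Set[str] = set()
    if log in ("gradients", "all"):
        outputs.add("gradient histograms")
    if log in ("parameters", "all"):
        outputs.add("parameter histograms")
    if log_graph:
        outputs.add("model graph")
    return outputs


print(watched_outputs(None, log_graph=True))  # graph only
```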
Watching multiple models
Call wandb.watch() separately for each model (useful in GAN training):
wandb.watch(generator, log="gradients", log_freq=100)
wandb.watch(discriminator, log="gradients", log_freq=100)
Each model’s histograms are logged under keys built from that model's parameter names.
Performance considerations
Histogram logging adds overhead that grows as log_freq gets smaller: logging every step (log_freq=1) can significantly slow training, and a value between 50 and 200 is typical for most training runs. If performance is critical, set log="parameters" rather than log="gradients". Parameter histograms are computed without a backward-pass hook and are cheaper, at the cost of losing the gradient histograms.
Stopping the watch
To stop histogram logging mid-training:
wandb.unwatch(model)
This removes the hooks without ending the run, so metric logging continues unaffected.
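The performance guidance above can be made concrete with a back-of-the-envelope count of histograms computed over a run. This is a rough sketch: histograms_per_run and its inputs are hypothetical, it assumes one snapshot per log_freq steps and one histogram per parameter tensor per kind, and it ignores per-step hook overhead (which is why log="parameters" can still be cheaper than log="gradients" at the same count). The optional stop_step models removing the hooks early with wandb.unwatch().

```python
from typing import Optional

# Histograms per parameter tensor per snapshot. Assumption: "all" means one
# gradient histogram plus one parameter histogram.
HISTS_PER_TENSOR = {"gradients": 1, "parameters": 1, "all": 2}


def histograms_per_run(
    total_steps: int,
    log_freq: int,
    num_param_tensors: int,
    log: str = "gradients",
    stop_step: Optional[int] = None,
) -> int:
    """Hypothetical estimate of histograms computed over a run (illustrative only)."""
    if log_freq < 1:
        raise ValueError("log_freq must be >= 1")
    # After the hooks are removed (stop_step), later steps add no histograms.
    watched_steps = total_steps if stop_step is None else min(total_steps, stop_step)
    snapshots = watched_steps // log_freq
    return snapshots * num_param_tensors * HISTS_PER_TENSOR[log]


print(histograms_per_run(10_000, 1, 20))    # every step
print(histograms_per_run(10_000, 100, 20))  # typical log_freq
```

For a hypothetical 20-tensor model trained for 10,000 steps, this gives 200,000 histograms at log_freq=1 versus 2,000 at log_freq=100, and unwatching at step 2,000 cuts the latter to 400.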