wandb.watch() hooks into a PyTorch model’s parameters and gradients and logs histograms of their values at regular intervals. This is useful for diagnosing training instability, vanishing gradients, and dead neurons.
Basic usage
Call wandb.watch() after wandb.init() and before the first training step:
Histograms are logged every log_freq batches (the Run.watch() default is log_freq=1000; the example uses 100 for faster feedback). They appear in the Charts tab under keys like gradients/layer_name.weight.
log parameter options
| Value | What is logged |
|---|---|
| "gradients" | Gradient histograms only (default) |
| "parameters" | Weight/parameter histograms only |
| "all" | Both gradients and parameters |
| None | Neither; no histograms are logged (graph topology only, if log_graph=True) |
Set log_graph=True when you want the computational graph while histogram logging is off or minimal. View the graph in the run’s Overview tab under Model. See Run.watch() for how log, log_graph, and log_freq interact.
Call wandb.watch() separately for each model (useful in GAN training):
Choose log_freq with care. Logging every step (log_freq=1) can significantly slow training. A value between 50 and 200 is typical for most training runs. If performance is critical, set log="parameters" rather than log="gradients": parameter histograms are computed without a backward pass hook and are cheaper.
Stopping the watch
To stop logging gradients mid-training, call wandb.unwatch():