Skip to main content

watch

Hooks into the given PyTorch model(s) to monitor gradients and the model's computational graph.

watch(
models: (torch.nn.Module | Sequence[torch.nn.Module]),
criterion: (torch.F | None) = None,
log: (Literal['gradients', 'parameters', 'all'] | None) = "gradients",
log_freq: int = 1000,
idx: (int | None) = None,
log_graph: bool = (False)
) -> None

This function can track parameters, gradients, or both during training. It should be extended to support arbitrary machine learning models in the future.

Args
models (Union[torch.nn.Module, Sequence[torch.nn.Module]]): A single model or a sequence of models to be monitored. criterion (Optional[torch.F]): The loss function being optimized (optional). log (Optional[Literal["gradients", "parameters", "all"]]): Specifies whether to log "gradients", "parameters", or "all". Set to None to disable logging. (default="gradients") log_freq (int): Frequency (in batches) to log gradients and parameters. (default=1000) idx (Optional[int]): Index used when tracking multiple models with wandb.watch. (default=None) log_graph (bool): Whether to log the model's computational graph. (default=False)
Raises
ValueErrorIf wandb.init has not been called or if any of the models are not instances of torch.nn.Module.
Was this page helpful?๐Ÿ‘๐Ÿ‘Ž