watch

Hook into the torch model to collect gradients and the topology.

watch(
models,
criterion=None,
log: Optional[Literal['gradients', 'parameters', 'all']] = "gradients",
log_freq: int = 1000,
idx: Optional[int] = None,
log_graph: bool = False
)

Should be extended to accept arbitrary ML models.

Args
models (torch.Module): The model to hook; can be a tuple.
criterion (torch.F): An optional loss value being optimized.
log (str): One of "gradients", "parameters", "all", or None.
log_freq (int): Log gradients and parameters every N batches.
idx (int): An index to be used when calling wandb.watch on multiple models.
log_graph (boolean): Log graph topology.
Returns
wandb.Graph: The graph object that will populate after the first backward pass.
Raises
ValueError: If called before wandb.init or if any of models is not a torch.nn.Module.
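
Example

A minimal usage sketch: watch is called after wandb.init and before training, so gradient and parameter histograms are logged during backward passes. The project name, model, and data below are illustrative, not part of the API.

import torch
import torch.nn as nn
import wandb

# Illustrative model, loss, and optimizer; any torch.nn.Module works.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

run = wandb.init(project="watch-demo")  # watch must be called after wandb.init

# Hook the model: log gradients and parameters every 100 batches,
# and capture the graph topology after the first backward pass.
wandb.watch(model, criterion, log="all", log_freq=100, log_graph=True)

for step in range(1000):
    x = torch.randn(64, 10)          # synthetic batch for illustration
    y = torch.randn(64, 1)
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()                  # gradient hooks fire here
    optimizer.step()
    wandb.log({"loss": loss.item()})

run.finish()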