wandb.watch

Hooks into a PyTorch model to collect its gradients and its topology (the model graph).

watch(
    models, criterion=None, log='gradients', log_freq=1000, idx=None
)

Currently only PyTorch models are supported; this should be extended to accept arbitrary ML models.

Args

models

(torch.nn.Module) The model to hook; can also be a tuple of models.

criterion

(callable) An optional loss function being optimized, e.g. one from torch.nn.functional.

log

(str or None) What to log: "gradients", "parameters", "all" (both gradients and parameters), or None (neither).

log_freq

(int) Log gradients and/or parameters every log_freq batches.

idx

(int) An index used to distinguish models when calling wandb.watch on multiple models.
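The following is a minimal, hypothetical sketch of the argument semantics described in the Args list, written in plain Python with no wandb or torch dependency. It is not wandb's implementation: the helper names resolve_log_targets, should_log, and as_model_tuple are illustrative inventions, and the sketch assumes batches are counted from 1, so that logging happens on every log_freq-th batch.

```python
# Hypothetical sketch of the documented wandb.watch argument semantics.
# This is NOT wandb's implementation; all helper names are illustrative.
from typing import FrozenSet, Optional

# What each documented `log` value asks to be recorded.
LOG_TARGETS = {
    "gradients": frozenset({"gradients"}),
    "parameters": frozenset({"parameters"}),
    "all": frozenset({"gradients", "parameters"}),
    None: frozenset(),
}


def resolve_log_targets(log: Optional[str]) -> FrozenSet[str]:
    """Map a `log` argument to the set of quantities to record."""
    if log not in LOG_TARGETS:
        raise ValueError(f"log must be one of {list(LOG_TARGETS)}, got {log!r}")
    return LOG_TARGETS[log]


def should_log(batch_count: int, log_freq: int) -> bool:
    """Return True on every log_freq-th batch.

    Assumption: batches are counted from 1, so with log_freq=1000 the
    first logged batch is batch 1000.
    """
    if log_freq < 1:
        raise ValueError("log_freq must be a positive integer")
    return batch_count % log_freq == 0


def as_model_tuple(models) -> tuple:
    """Normalize `models` (a single model or a tuple of models) to a tuple."""
    return models if isinstance(models, tuple) else (models,)


if __name__ == "__main__":
    print(sorted(resolve_log_targets("all")))  # ['gradients', 'parameters']
    print([b for b in range(1, 3001) if should_log(b, 1000)])  # [1000, 2000, 3000]
```

Under this counting assumption, the default log_freq=1000 over a 3000-batch run records at batches 1000, 2000, and 3000. Treating a single model and a tuple uniformly is one simple way to honor "can be a tuple" without branching elsewhere.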

Returns

(wandb.Graph) The graph object, which is populated after the first backward pass.