メインコンテンツまでスキップ

watch

torchモデルにフックして、勾配とトポロジーを収集します。

watch(
models,
criterion=None,
log: Optional[Literal['gradients', 'parameters', 'all']] = "gradients",
log_freq: int = 1000,
idx: Optional[int] = None,
log_graph: bool = (False)
)

任意のMLモデルを受け入れるように拡張する必要があります。

引数
models(torch.Module) フックするモデル、タプルも可能
criterion(torch.F) オプションの最適化対象の損失値
log(str) "gradients", "parameters", "all", またはNoneのいずれか
log_freq(int) 勾配とパラメータをNバッチごとにログする
idx(int) 複数のモデルに対してwandb.watchを呼び出す際に使用するインデックス
log_graph(boolean) グラフトポロジーをログする
戻り値
wandb.Graph: 最初のbackward pass後にポピュレートされるグラフオブジェクト
例外
ValueErrorwandb.initの前に呼び出された場合、またはいずれかのモデルがtorch.nn.Moduleでない場合
Was this page helpful?👍👎