Skip to main content
wandb.watch() は PyTorch モデルのパラメーターと勾配にフックして、それらの値のヒストグラムを一定間隔でログします。これは、トレーニングの不安定性、勾配消失、ニューロンの機能停止の診断に役立ちます。 基本的な使い方 wandb.init() の後、最初のトレーニングステップの前に wandb.watch() を呼び出します。
import wandb
import torch.nn as nn

wandb.init(project="my-project")

model = MyModel()
wandb.watch(model, log="gradients", log_freq=100)

for step, batch in enumerate(dataloader):
    loss = train_step(model, batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    wandb.log({"train/loss": loss.item()}, step=step)

wandb.finish()
勾配ヒストグラムは log_freq バッチごとにログされます (Run.watch() のデフォルトは log_freq=1000 で、より速くフィードバックを得るため、この例では 100 を使用しています) 。これらは Charts タブに、gradients/layer_name.weight のようなキー名で表示されます。 log パラメーターのオプション
Valueログされる内容
"gradients"勾配ヒストグラムのみ (デフォルト)
"parameters"重み / パラメーターのヒストグラムのみ
"all"勾配とパラメーターの両方
Noneどちらもログされず、モデルグラフのトポロジのみ
wandb.watch(model, log="all", log_freq=50)
モデルグラフのトポロジをログする ヒストグラムのログをオフにしている場合や最小限に抑えている場合でも計算グラフを取得したいときは、log_graph=True を渡してください。グラフは、run の Overview タブの Model セクションで表示できます。loglog_graphlog_freq がどのように連動するかについては、Run.watch() を参照してください。
wandb.watch(model, log=None, log_graph=True)  # グラフに注目、ヒストグラムなし
複数のモデルを監視する 各モデルごとに wandb.watch() を個別に呼び出します (GAN のトレーニング時に便利です) :
wandb.watch(generator, log="gradients", log_freq=100)
wandb.watch(discriminator, log="gradients", log_freq=100)
各モデルの勾配は、パラメーター名を接頭辞としてログされます。 パフォーマンスに関する考慮事項 勾配のログ記録には、log_freq に比例したオーバーヘッドが発生します。すべてのステップでログを記録すると (log_freq=1) 、トレーニングが大幅に遅くなる可能性があります。ほとんどのトレーニング run では、50 から 200 の値が一般的です。パフォーマンスが重要な場合は、log="gradients" ではなく log="parameters" を設定してください。パラメーターのヒストグラムは backward pass の hook を使わずに計算されるため、より低コストです。 watch の停止 トレーニングの途中で勾配のログ記録を停止するには:
wandb.unwatch(model)
これにより、run を終了せずにフックを削除できるため、メトリクスのロギングはそのまま継続されます。
Experiments メトリクス Runs