wandb.watch() は PyTorch モデルのパラメーターと勾配にフックして、それらの値のヒストグラムを一定間隔でログします。これは、トレーニングの不安定性、勾配消失、ニューロンの機能停止の診断に役立ちます。
基本的な使い方
wandb.init() の後、最初のトレーニングステップの前に wandb.watch() を呼び出します。
log_freq バッチごとにログされます (Run.watch() のデフォルトは log_freq=1000 で、より速くフィードバックを得るため、この例では 100 を使用しています) 。これらは Charts タブに、gradients/layer_name.weight のようなキー名で表示されます。
log パラメーターのオプション
| Value | ログされる内容 |
|---|---|
"gradients" | 勾配ヒストグラムのみ (デフォルト) |
"parameters" | 重み / パラメーターのヒストグラムのみ |
"all" | 勾配とパラメーターの両方 |
None | どちらもログされず、モデルグラフのトポロジのみ |
log_graph=True を渡してください。グラフは、run の Overview タブの Model セクションで表示できます。log、log_graph、log_freq がどのように連動するかについては、Run.watch() を参照してください。
wandb.watch() を個別に呼び出します (GAN のトレーニング時に便利です) :
log_freq に比例したオーバーヘッドが発生します。すべてのステップでログを記録すると (log_freq=1) 、トレーニングが大幅に遅くなる可能性があります。ほとんどのトレーニング run では、50 から 200 の値が一般的です。パフォーマンスが重要な場合は、log="gradients" ではなく log="parameters" を設定してください。パラメーターのヒストグラムは backward pass の hook を使わずに計算されるため、より低コストです。
watch の停止
トレーニングの途中で勾配のログ記録を停止するには:
Experiments メトリクス Runs