Skip to main content
PyTorch は Python で最も人気のある ディープラーニング フレームワーク の一つであり、特に研究者の間で広く利用されています。 W&B は、 勾配 の ログ 記録から CPU および GPU での コード のプロファイリングまで、 PyTorch を第一級市民としてサポートしています。 また、 example repo では スクリプト の例を確認できます。これには、 Fashion MNISTHyperband を使用した ハイパーパラメーター 最適化の例や、それによって生成された W&B Dashboard が含まれています。

run.watch による 勾配 の ログ 記録

勾配 を自動的に ログ 記録するには、 wandb.Run.watch() を呼び出し、 PyTorch モデル を渡します。
import wandb

with wandb.init(config=args) as run:

    model = ...  # モデルのセットアップ

    # マジックメソッド
    run.watch(model, log_freq=100)

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            run.log({"loss": loss})
同じ スクリプト 内で複数の モデル を追跡する必要がある場合は、各 モデル に対して個別に wandb.Run.watch() を呼び出すことができます。
勾配、 メトリクス、およびグラフは、順伝播(forward pass) および 逆伝播(backward pass)の後に wandb.Run.log() が呼び出されるまで ログ 記録されません。

画像とメディアの ログ 記録

画像 データを含む PyTorch の Tensorswandb.Image に渡すと、 torchvision のユーティリティが使用され、自動的に画像に変換されます。
with wandb.init(project="my_project", entity="my_entity") as run:
    images_t = ...  # PyTorch Tensorsとして画像を生成またはロード
    run.log({"examples": [wandb.Image(im) for im in images_t]})
PyTorch やその他の フレームワーク でリッチメディアを W&B に ログ 記録する方法の詳細については、 メディアロギングガイド を参照してください。 メディアと一緒に、 モデル の 予測 や派生した メトリクス などの 情報 を含めたい場合は、 wandb.Table を使用します。
with wandb.init() as run:
    my_table = wandb.Table()

    my_table.add_column("image", images_t)
    my_table.add_column("label", labels)
    my_table.add_column("class_prediction", predictions_t)

    # W&BにTableをログ記録
    run.log({"mnist_predictions": my_table})
PyTorch model results
データセット や モデル の ログ 記録と可視化の詳細については、 W&B Tables ガイド を参照してください。

PyTorch コード のプロファイリング

PyTorch execution traces
W&B は PyTorch KinetoTensorboard プラグイン と直接連携し、 PyTorch コード のプロファイリング、 CPU および GPU 通信の詳細な検査、ボトルネックの特定と最適化のための ツール を提供します。
profile_dir = "path/to/run/tbprofile/"
profiler = torch.profiler.profile(
    schedule=schedule,  # スケジューリングの詳細についてはprofilerのドキュメントを参照
    on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
    with_stack=True,
)

with profiler:
    ...  # ここでプロファイリングしたいコードを実行
    # 詳細な使用方法についてはprofilerのドキュメントを参照

# wandb Artifactを作成
profile_art = wandb.Artifact("trace", type="profile")
# pt.trace.jsonファイルをArtifactに追加
profile_art.add_file(glob.glob(profile_dir + ".pt.trace.json"))
# artifactを保存
profile_art.save()
動作するサンプル コード は、 この Colab で確認および実行できます。
インタラクティブな トレース 閲覧 ツール は Chrome Trace Viewer に基づいており、 Google Chrome ブラウザで最適に動作します。