Skip to main content
W&B には JAX 固有のインテグレーションはありません。ただし、JAX のデバイス配列を Python スカラーに変換してから、他の Python ワークフローと同様にトレーニングループ内で wandb.log() を使用できます。ログ前に配列を変換することで、値を正しくシリアライズできます。 実験の設定とログのパターンについては、実験を作成するオブジェクトとメディアをログする を参照してください。Flax のチェックポイントについては、アーティファクト を参照してください。 以下のセクションでは、JAX のトレーニングループにおける基本的なログのパターン、検証メトリクスを集約する方法、JAX または Flax のチェックポイントを W&B アーティファクトとして保存する方法、そしてトレーニング中に NaN の損失を検出する方法を示します。

JAX のトレーニングループからメトリクスをログする

次の例では、JIT コンパイル済みの JAX トレーニングループを実行し、各バッチでモデルのパラメーターを更新しながら、ステップごとのトレーニング損失を W&B run にログします。
import jax
import wandb

with wandb.init(
    project="my-jax-project",
    config={"learning_rate": 1e-3, "batch_size": 64, "epochs": 50},
) as run:
    lr = run.config.learning_rate

    @jax.jit
    def train_step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        params = update_params(params, grads)
        return params, loss

    for step, batch in enumerate(dataloader):
        params, loss = train_step(params, batch)
        run.log({"train/loss": float(loss)}, step=step)
@jax.jit から返される値はデバイス配列です。スカラー値を run.log() に渡すには、float(loss) または 0 次元配列に対して .item() を使用してください。JAX 配列をそのままログすると、SDK のバージョンによってはシリアライズに失敗したり、想定外の値が記録されたりすることがあります。

検証メトリクスをログする

Python で検証メトリクスを集計し、エポックごとに 1 回ログします:
for epoch in range(num_epochs):
    val_losses = []
    for val_batch in val_loader:
        val_loss = eval_step(params, val_batch)
        val_losses.append(float(val_loss))

    run.log({
        "epoch": epoch,
        "val/loss": sum(val_losses) / len(val_losses),
    })

チェックポイントをアーティファクトとして保存する

orbax または flax.serialization を使用して JAX または Flax のパラメーターを保存し、アーティファクトとしてログします:
import orbax.checkpoint as ocp

checkpointer = ocp.StandardCheckpointer()
checkpointer.save("/tmp/checkpoint", params)
checkpointer.wait_until_finished()

artifact = wandb.Artifact("jax-model", type="model")
artifact.add_dir("/tmp/checkpoint")
run.log_artifact(artifact)

NaN 値をデバッグする

JAX では、デフォルトで NaN 値が発生しても例外は送出されません。損失と一緒に NaN フラグをログしてください。
import jax.numpy as jnp

run.log({
    "train/loss": float(loss),
    "train/loss_is_nan": bool(jnp.isnan(loss)),
})
開発中は、JAX のデバッグ用 NaN チェックを有効にできます (パフォーマンスに影響があります) :
from jax import config
config.update("jax_debug_nans", True)

Experiments Runs メトリクス