W&B has no JAX-specific integration. Instead, convert JAX arrays to Python scalars and log them with run.log() in your training loop, as you would in any other Python workflow. Converting the arrays first ensures that the values serialize correctly. For experiment setup and logging patterns, see Create an experiment and Log objects and media. For Flax checkpoints, see Artifacts.

The sections below show the basic logging pattern for a JAX training loop, how to aggregate validation metrics, how to save JAX or Flax checkpoints as W&B artifacts, and how to surface NaN losses during training.

Log metrics from a JAX training loop

The following example runs a JIT-compiled JAX training loop that updates model parameters on each batch and logs the per-step training loss to a W&B run.
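import jax
import wandb

# Assumes loss_fn(params, batch), the initial params pytree, and dataloader
# are defined elsewhere in your code.
with wandb.init(
    project="my-jax-project",
    config={"learning_rate": 1e-3, "batch_size": 64, "epochs": 50},
) as run:
    lr = run.config.learning_rate

    @jax.jit
    def train_step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        # Plain SGD update using lr; swap in your own optimizer (for example, Optax).
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return params, loss

    for step, batch in enumerate(dataloader):
        params, loss = train_step(params, batch)
        # loss is a 0-d JAX array; convert it to a Python float before logging.
        run.log({"train/loss": float(loss)}, step=step)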
Functions decorated with @jax.jit return JAX arrays, not Python numbers. Convert each value to a Python scalar before passing it to run.log(), either with float(loss) or by calling .item() on a 0-dimensional array. Logging a raw JAX array can fail to serialize or record an unexpected value, depending on your SDK version.
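When a jitted step returns several metrics at once, for example as a dict, converting each value by hand gets repetitive. The sketch below assumes your step returns a pytree of 0-dimensional arrays; to_loggable and compute_metrics are hypothetical names used only for illustration. It copies every value to the host with jax.device_get and then calls .item() on each leaf.

```python
import jax
import jax.numpy as jnp


def to_loggable(metrics):
    """Hypothetical helper: convert a pytree of 0-d JAX arrays to Python scalars.

    jax.device_get copies every leaf to host memory in one call; .item() then
    turns each 0-d NumPy array into a plain Python number that wandb can serialize.
    """
    host_metrics = jax.device_get(metrics)
    return jax.tree_util.tree_map(lambda x: x.item(), host_metrics)


@jax.jit
def compute_metrics(preds, targets):
    # Hypothetical metrics: mean squared error and mean absolute error.
    err = preds - targets
    return {"train/mse": jnp.mean(err**2), "train/mae": jnp.mean(jnp.abs(err))}


metrics = compute_metrics(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 2.0, 5.0]))
loggable = to_loggable(metrics)
# Inside a run, pass the converted dict directly: run.log(loggable, step=step)
```

The result keeps the same keys as the original dict, so metric names in the W&B UI do not change.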

Log validation metrics

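Aggregate validation metrics in Python and log them once per epoch. This loop runs inside the same with wandb.init(...) as run: block; eval_step, val_loader, and num_epochs stand in for your own code:
for epoch in range(num_epochs):
    val_losses = []
    for val_batch in val_loader:
        val_loss = eval_step(params, val_batch)
        # Convert each 0-d device array to a Python float.
        val_losses.append(float(val_loss))

    # Log the mean of the per-batch validation losses once per epoch.
    run.log({
        "epoch": epoch,
        "val/loss": sum(val_losses) / len(val_losses),
    })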
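Calling float() inside the batch loop copies each loss from the device to the host, which makes JAX wait for that batch to finish before the loop continues. One alternative, sketched below under the assumption that eval_step returns a 0-dimensional loss array, keeps the per-batch losses on device and converts only the final mean. It computes the same unweighted mean of per-batch losses as the loop above. eval_step, mean_val_loss, and the toy linear model are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp


@jax.jit
def eval_step(params, batch):
    # Hypothetical evaluation step: squared error of a linear model y = w * x.
    x, y = batch
    return jnp.mean((params["w"] * x - y) ** 2)


def mean_val_loss(params, val_loader):
    """Average per-batch losses on device and transfer one scalar to the host."""
    losses = [eval_step(params, batch) for batch in val_loader]
    if not losses:
        raise ValueError("val_loader produced no batches")
    # Single device-to-host transfer after all batches have been dispatched.
    return float(jnp.mean(jnp.stack(losses)))


params = {"w": jnp.float32(2.0)}
val_loader = [
    (jnp.array([1.0, 2.0]), jnp.array([2.0, 4.0])),  # perfect fit: loss 0
    (jnp.array([1.0, 1.0]), jnp.array([3.0, 3.0])),  # error of 1 per point: loss 1
]
val_loss = mean_val_loss(params, val_loader)
# Inside a run: run.log({"epoch": epoch, "val/loss": val_loss})
```

The explicit empty-loader check replaces the ZeroDivisionError that sum(val_losses) / len(val_losses) would raise in that case.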

Save checkpoints as artifacts

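Save JAX or Flax parameters with Orbax or flax.serialization, then log the saved files as a W&B artifact. The following example uses Orbax and runs inside an active run:
import orbax.checkpoint as ocp
import wandb

# params is your model's parameter pytree.
checkpointer = ocp.StandardCheckpointer()
checkpointer.save("/tmp/checkpoint", params)
# StandardCheckpointer saves asynchronously; wait for the write to finish before uploading.
checkpointer.wait_until_finished()

# Package the checkpoint directory as a versioned W&B model artifact.
artifact = wandb.Artifact("jax-model", type="model")
artifact.add_dir("/tmp/checkpoint")
run.log_artifact(artifact)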

Debug NaN values

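By default, JAX does not raise an error when a computation produces a NaN; the NaN propagates silently through later operations. To spot NaN losses in the W&B UI, log a NaN flag alongside the loss:
import jax.numpy as jnp

# Runs inside your training loop, after loss is returned from the jitted step.
run.log({
    "train/loss": float(loss),
    "train/loss_is_nan": bool(jnp.isnan(loss)),
})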
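The flag above is computed on the host after each step. As a sketch, you can instead compute the NaN flag inside the jitted step, together with a global gradient norm, and transfer all metrics to the host in one call. loss_fn, train_step, to_log_dict, and LEARNING_RATE are hypothetical; the toy loss takes a log, so a negative input produces a NaN on purpose.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical fixed learning rate for a plain SGD update


def loss_fn(params, batch):
    # Hypothetical loss: mean log of a linear prediction; NaN whenever w * x < 0.
    return jnp.mean(jnp.log(params["w"] * batch))


@jax.jit
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    # Global L2 norm over every leaf of the gradient pytree.
    grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grads)))
    metrics = {
        "train/loss": loss,
        "train/grad_norm": grad_norm,
        "train/loss_is_nan": jnp.isnan(loss),
        "train/grad_norm_is_finite": jnp.isfinite(grad_norm),
    }
    return new_params, metrics


def to_log_dict(metrics):
    # One device-to-host transfer; .item() turns each 0-d array into a
    # Python float or bool that wandb can serialize.
    return {k: v.item() for k, v in jax.device_get(metrics).items()}


params = {"w": jnp.float32(1.0)}
_, ok_metrics = train_step(params, jnp.array([1.0, 2.0]))
_, bad_metrics = train_step(params, jnp.array([-1.0, 2.0]))
ok_log, bad_log = to_log_dict(ok_metrics), to_log_dict(bad_metrics)
# In a training loop: run.log(to_log_dict(metrics), step=step)
```

Because the flags are ordinary outputs of the jitted function, they add no extra compilation and arrive on the host together with the loss.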
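During development, you can enable jax_debug_nans so that JAX raises a FloatingPointError as soon as a computation produces a NaN. The check comes at a performance cost, so turn it off for full training runs:
from jax import config

# Enable before running the code you want to check.
config.update("jax_debug_nans", True)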
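To see the effect of jax_debug_nans, the sketch below enables it around a single call and restores the default afterward. unstable_step and raises_on_nan are hypothetical helpers; unstable_step divides zero by zero to produce a NaN on purpose.

```python
import jax
import jax.numpy as jnp


@jax.jit
def unstable_step(x):
    # Hypothetical computation that produces NaN when x == 0 (0 / 0).
    return x / x


def raises_on_nan(fn, *args):
    """Return True if fn(*args) raises FloatingPointError with jax_debug_nans on."""
    jax.config.update("jax_debug_nans", True)
    try:
        # Block so any NaN check finishes inside the try block.
        fn(*args).block_until_ready()
        return False
    except FloatingPointError:
        return True
    finally:
        # Restore the default so the rest of the program runs at full speed.
        jax.config.update("jax_debug_nans", False)


caught_zero = raises_on_nan(unstable_step, jnp.float32(0.0))
caught_two = raises_on_nan(unstable_step, jnp.float32(2.0))
```

When the NaN appears in the output of a jit-compiled function, JAX reruns the function operation by operation to find the primitive that produced it, so the traceback points at the offending line rather than at the jitted call.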
