wandb.log() in your training loop like any other Python workflow. Convert JAX arrays to Python scalars before logging so that their values serialize correctly.
For experiment setup and logging patterns, see Create an experiment and Log objects and media. For Flax checkpoints, see Artifacts.
The sections below show the basic logging pattern for a JAX training loop, how to aggregate validation metrics, how to save JAX or Flax checkpoints as W&B artifacts, and how to surface NaN losses during training.
Log metrics from a JAX training loop
A typical pattern is a JIT-compiled JAX training loop that updates model parameters on each batch and logs the per-step training loss to a W&B run.
Values returned by functions decorated with @jax.jit are device arrays. Convert them to Python scalars before passing them to run.log(), using float(loss) or .item() on 0-dimensional arrays. Logging a raw JAX array can fail serialization or record an unexpected value, depending on your SDK version.
Log validation metrics
Aggregate validation metrics in Python and log them once per epoch.
Save checkpoints as artifacts
Save JAX or Flax parameters with orbax or flax.serialization, then log the saved files as W&B artifacts.
Debug NaN values
JAX does not raise an error on NaN values by default, so a NaN loss can propagate through later steps without stopping training. Log a NaN flag alongside your loss.