Skip to main content
W&B ne dispose pas d’intégration spécifique à JAX. Vous pouvez toutefois convertir les tableaux JAX sur périphérique en scalaires Python, puis utiliser wandb.log() dans votre boucle d’entraînement comme dans n’importe quel autre flux de travail Python. Convertir les tableaux avant la journalisation garantit que les valeurs sont correctement sérialisées. Pour la configuration des expériences et les schémas de journalisation, voir Créer une expérience et Journaliser des objets et des médias. Pour les points de contrôle du modèle Flax, voir Artifacts. Les sections ci-dessous montrent le schéma de journalisation de base pour une boucle d’entraînement JAX, comment agréger les métriques de validation, comment enregistrer des points de contrôle du modèle JAX ou Flax comme artefacts W&B, et comment signaler les pertes NaN pendant l’entraînement.

Consigner des métriques depuis une boucle d’entraînement JAX

L’exemple suivant exécute une boucle d’entraînement JAX compilée en JIT, qui met à jour les paramètres du modèle à chaque lot et consigne la perte d’entraînement à chaque étape dans un run W&B.
import jax
import wandb

with wandb.init(
    project="my-jax-project",
    config={"learning_rate": 1e-3, "batch_size": 64, "epochs": 50},
) as run:
    lr = run.config.learning_rate

    @jax.jit
    def train_step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        params = update_params(params, grads)
        return params, loss

    for step, batch in enumerate(dataloader):
        params, loss = train_step(params, batch)
        run.log({"train/loss": float(loss)}, step=step)
Les valeurs renvoyées par @jax.jit sont des tableaux sur le périphérique. Transmettez les scalaires à run.log() avec float(loss) ou .item() sur des tableaux de dimension 0. La journalisation d’un tableau JAX brut peut échouer lors de la sérialisation ou enregistrer une valeur inattendue selon la version de votre SDK.

Journaliser les métriques de validation

Agrégez les métriques de validation en Python, puis journalisez-les une fois par époque :
for epoch in range(num_epochs):
    val_losses = []
    for val_batch in val_loader:
        val_loss = eval_step(params, val_batch)
        val_losses.append(float(val_loss))

    run.log({
        "epoch": epoch,
        "val/loss": sum(val_losses) / len(val_losses),
    })

Enregistrer des points de contrôle du modèle comme artefacts

Enregistrez les paramètres JAX ou Flax avec orbax ou flax.serialization, puis enregistrez-les comme artefacts :
import orbax.checkpoint as ocp

checkpointer = ocp.StandardCheckpointer()
checkpointer.save("/tmp/checkpoint", params)
checkpointer.wait_until_finished()

artifact = wandb.Artifact("jax-model", type="model")
artifact.add_dir("/tmp/checkpoint")
run.log_artifact(artifact)

Déboguer les valeurs NaN

Par défaut, JAX ne lève pas d’exception pour les valeurs NaN. Enregistrez dans le journal un indicateur NaN avec votre perte :
import jax.numpy as jnp

run.log({
    "train/loss": float(loss),
    "train/loss_is_nan": bool(jnp.isnan(loss)),
})
Pendant le développement, vous pouvez activer la vérification des NaN en mode débogage dans JAX (au prix d’une baisse de performances) :
from jax import config
config.update("jax_debug_nans", True)

Experiments Runs Métriques