> ## Documentation Index
> Fetch the complete documentation index at: https://docs.wandb.ai/llms.txt
> Use this file to discover all available pages before exploring further.

# Comment puis-je utiliser W&B avec JAX ?

W\&B ne dispose pas d’intégration spécifique à JAX. Vous pouvez toutefois convertir les tableaux JAX sur périphérique en scalaires Python, puis utiliser `wandb.log()` dans votre boucle d’entraînement comme dans n’importe quel autre flux de travail Python. Convertir les tableaux avant la journalisation garantit que les valeurs sont correctement sérialisées.

Pour la configuration des expériences et les schémas de journalisation, voir [Créer une expérience](/fr/models/track/create-an-experiment) et [Journaliser des objets et des médias](/fr/models/track/log). Pour les points de contrôle du modèle Flax, voir [Artifacts](/fr/models/artifacts/).

Les sections ci-dessous montrent le schéma de journalisation de base pour une boucle d’entraînement JAX, comment agréger les métriques de validation, comment enregistrer des points de contrôle du modèle JAX ou Flax comme artefacts W\&B, et comment signaler les pertes NaN pendant l’entraînement.

<div id="log-metrics-from-a-jax-training-loop">
  ## Consigner des métriques depuis une boucle d’entraînement JAX
</div>

L’exemple suivant exécute une boucle d’entraînement JAX compilée en JIT, qui met à jour les paramètres du modèle à chaque lot et consigne la perte d’entraînement à chaque étape dans un run W\&B.

```python theme={null}
import jax
import wandb

with wandb.init(
    project="my-jax-project",
    config={"learning_rate": 1e-3, "batch_size": 64, "epochs": 50},
) as run:
    lr = run.config.learning_rate

    @jax.jit
    def train_step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        params = update_params(params, grads)
        return params, loss

    for step, batch in enumerate(dataloader):
        params, loss = train_step(params, batch)
        run.log({"train/loss": float(loss)}, step=step)
```

Les valeurs renvoyées par `@jax.jit` sont des tableaux sur le périphérique. Transmettez les scalaires à `run.log()` avec `float(loss)` ou `.item()` sur des tableaux de dimension 0. La journalisation d’un tableau JAX brut peut échouer lors de la sérialisation ou enregistrer une valeur inattendue selon la version de votre SDK.

<div id="log-validation-metrics">
  ## Journaliser les métriques de validation
</div>

Agrégez les métriques de validation en Python, puis journalisez-les une fois par époque :

```python theme={null}
for epoch in range(num_epochs):
    val_losses = []
    for val_batch in val_loader:
        val_loss = eval_step(params, val_batch)
        val_losses.append(float(val_loss))

    run.log({
        "epoch": epoch,
        "val/loss": sum(val_losses) / len(val_losses),
    })
```

<div id="save-checkpoints-as-artifacts">
  ## Enregistrer des points de contrôle du modèle comme artefacts
</div>

Enregistrez les paramètres JAX ou Flax avec `orbax` ou `flax.serialization`, puis enregistrez-les comme artefacts :

```python theme={null}
import orbax.checkpoint as ocp

checkpointer = ocp.StandardCheckpointer()
checkpointer.save("/tmp/checkpoint", params)
checkpointer.wait_until_finished()

artifact = wandb.Artifact("jax-model", type="model")
artifact.add_dir("/tmp/checkpoint")
run.log_artifact(artifact)
```

<div id="debug-nan-values">
  ## Déboguer les valeurs NaN
</div>

Par défaut, JAX ne lève pas d’exception pour les valeurs NaN. Enregistrez dans le journal un indicateur NaN avec votre perte :

```python theme={null}
import jax.numpy as jnp

run.log({
    "train/loss": float(loss),
    "train/loss_is_nan": bool(jnp.isnan(loss)),
})
```

Pendant le développement, vous pouvez activer la vérification des NaN en mode débogage dans JAX (au prix d'une baisse de performances) :

```python theme={null}
from jax import config
config.update("jax_debug_nans", True)
```

***

<Badge stroke shape="pill" color="orange" size="md">[Experiments](/fr/support/models/tags/experiments)</Badge><Badge stroke shape="pill" color="orange" size="md">[Runs](/fr/support/models/tags/runs)</Badge><Badge stroke shape="pill" color="orange" size="md">[Métriques](/fr/support/models/tags/metrics)</Badge>
