wandb.log() dans votre boucle d’entraînement comme dans n’importe quel autre flux de travail Python. Convertir les tableaux avant la journalisation garantit que les valeurs sont correctement sérialisées.
Pour la configuration des expériences et les schémas de journalisation, voir Créer une expérience et Journaliser des objets et des médias. Pour les points de contrôle du modèle Flax, voir Artifacts.
Les sections ci-dessous montrent le schéma de journalisation de base pour une boucle d’entraînement JAX, comment agréger les métriques de validation, comment enregistrer des points de contrôle du modèle JAX ou Flax comme artefacts W&B, et comment signaler les pertes NaN pendant l’entraînement.
Consigner des métriques depuis une boucle d’entraînement JAX
@jax.jit sont des tableaux sur le périphérique. Transmettez les scalaires à run.log() avec float(loss) ou .item() sur des tableaux de dimension 0. La journalisation d’un tableau JAX brut peut échouer lors de la sérialisation ou enregistrer une valeur inattendue selon la version de votre SDK.
Journaliser les métriques de validation
Enregistrer des points de contrôle du modèle comme artefacts
orbax ou flax.serialization, puis enregistrez-les comme artefacts :
Déboguer les valeurs NaN
Experiments Runs Métriques