wandb.log() を使用できます。ログ前に配列を変換することで、値を正しくシリアライズできます。
実験の設定とログのパターンについては、実験を作成する と オブジェクトとメディアをログする を参照してください。Flax のチェックポイントについては、アーティファクト を参照してください。
以下のセクションでは、JAX のトレーニングループにおける基本的なログのパターン、検証メトリクスを集約する方法、JAX または Flax のチェックポイントを W&B アーティファクトとして保存する方法、そしてトレーニング中に NaN の損失を検出する方法を示します。
JAX のトレーニングループからメトリクスをログする
@jax.jit から返される値はデバイス配列です。スカラー値を run.log() に渡すには、float(loss) または 0 次元配列に対して .item() を使用してください。JAX 配列をそのままログすると、SDK のバージョンによってはシリアライズに失敗したり、想定外の値が記録されたりすることがあります。
検証メトリクスをログする
チェックポイントをアーティファクトとして保存する
orbax または flax.serialization を使用して JAX または Flax のパラメーターを保存し、アーティファクトとしてログします:
NaN 値をデバッグする
Experiments Runs メトリクス