> ## Documentation Index
> Fetch the complete documentation index at: https://docs.wandb.ai/llms.txt
> Use this file to discover all available pages before exploring further.

# W&B を JAX で使用するにはどうすればよいですか？

W\&B には JAX 固有のインテグレーションはありません。ただし、JAX のデバイス配列を Python スカラーに変換してから、他の Python ワークフローと同様にトレーニングループ内で `wandb.log()` を使用できます。ログ前に配列を変換することで、値を正しくシリアライズできます。

実験の設定とログのパターンについては、[実験を作成する](/ja/models/track/create-an-experiment) と [オブジェクトとメディアをログする](/ja/models/track/log) を参照してください。Flax のチェックポイントについては、[アーティファクト](/ja/models/artifacts/) を参照してください。

以下のセクションでは、JAX のトレーニングループにおける基本的なログのパターン、検証メトリクスを集約する方法、JAX または Flax のチェックポイントを W\&B アーティファクトとして保存する方法、そしてトレーニング中に NaN の損失を検出する方法を示します。

<div id="log-metrics-from-a-jax-training-loop">
  ## JAX のトレーニングループからメトリクスをログする
</div>

次の例では、JIT コンパイル済みの JAX トレーニングループを実行し、各バッチでモデルのパラメーターを更新しながら、ステップごとのトレーニング損失を W\&B run にログします。

```python theme={null}
import jax
import wandb

with wandb.init(
    project="my-jax-project",
    config={"learning_rate": 1e-3, "batch_size": 64, "epochs": 50},
) as run:
    lr = run.config.learning_rate

    @jax.jit
    def train_step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        params = update_params(params, grads)
        return params, loss

    for step, batch in enumerate(dataloader):
        params, loss = train_step(params, batch)
        run.log({"train/loss": float(loss)}, step=step)
```

`@jax.jit` から返される値はデバイス配列です。スカラー値を `run.log()` に渡すには、`float(loss)` または 0 次元配列に対して `.item()` を使用してください。JAX 配列をそのままログすると、SDK のバージョンによってはシリアライズに失敗したり、想定外の値が記録されたりすることがあります。

<div id="log-validation-metrics">
  ## 検証メトリクスをログする
</div>

Python で検証メトリクスを集計し、エポックごとに 1 回ログします:

```python theme={null}
for epoch in range(num_epochs):
    val_losses = []
    for val_batch in val_loader:
        val_loss = eval_step(params, val_batch)
        val_losses.append(float(val_loss))

    run.log({
        "epoch": epoch,
        "val/loss": sum(val_losses) / len(val_losses),
    })
```

<div id="save-checkpoints-as-artifacts">
  ## チェックポイントをアーティファクトとして保存する
</div>

`orbax` または `flax.serialization` を使用して JAX または Flax のパラメーターを保存し、アーティファクトとしてログします:

```python theme={null}
import orbax.checkpoint as ocp

checkpointer = ocp.StandardCheckpointer()
checkpointer.save("/tmp/checkpoint", params)
checkpointer.wait_until_finished()

artifact = wandb.Artifact("jax-model", type="model")
artifact.add_dir("/tmp/checkpoint")
run.log_artifact(artifact)
```

<div id="debug-nan-values">
  ## NaN 値をデバッグする
</div>

JAX では、デフォルトで NaN 値が発生しても例外は送出されません。損失と一緒に NaN フラグをログしてください。

```python theme={null}
import jax.numpy as jnp

run.log({
    "train/loss": float(loss),
    "train/loss_is_nan": bool(jnp.isnan(loss)),
})
```

開発中は、JAX のデバッグ用 NaN チェックを有効にできます (パフォーマンスに影響があります) ：

```python theme={null}
from jax import config
config.update("jax_debug_nans", True)
```

***

<Badge stroke shape="pill" color="orange" size="md">[Experiments](/ja/support/models/tags/experiments)</Badge><Badge stroke shape="pill" color="orange" size="md">[Runs](/ja/support/models/tags/runs)</Badge><Badge stroke shape="pill" color="orange" size="md">[メトリクス](/ja/support/models/tags/metrics)</Badge>
