> ## Documentation Index
> Fetch the complete documentation index at: https://docs.wandb.ai/llms.txt
> Use this file to discover all available pages before exploring further.

# W&B를 JAX와 함께 사용하려면 어떻게 해야 하나요?

W\&B에는 JAX 전용 인테그레이션이 없습니다. 하지만 JAX 디바이스 배열을 Python 스칼라로 변환한 다음, 다른 Python 워크플로와 마찬가지로 트레이닝 루프에서 `wandb.log()`를 사용할 수 있습니다. 로깅 전에 배열을 변환하면 값이 올바르게 직렬화됩니다.

실험 설정과 로깅 패턴은 [실험 만들기](/ko/models/track/create-an-experiment) 및 [객체와 미디어 로깅](/ko/models/track/log)을 참조하세요. Flax 체크포인트에 대해서는 [Artifacts](/ko/models/artifacts/)를 참조하세요.

아래 섹션에서는 JAX 트레이닝 루프의 기본 로깅 패턴, 검증 메트릭을 집계하는 방법, JAX 또는 Flax 체크포인트를 W\&B 아티팩트로 저장하는 방법, 그리고 트레이닝 중 NaN loss를 드러내는 방법을 설명합니다.

<div id="log-metrics-from-a-jax-training-loop">
  ## JAX 트레이닝 루프에서 메트릭 로깅하기
</div>

다음 예시에서는 JIT 컴파일된 JAX 트레이닝 루프를 실행하여 각 batch마다 모델 파라미터를 업데이트하고, step별 트레이닝 loss를 W\&B run에 로깅합니다.

```python theme={null}
import jax
import wandb

with wandb.init(
    project="my-jax-project",
    config={"learning_rate": 1e-3, "batch_size": 64, "epochs": 50},
) as run:
    lr = run.config.learning_rate

    @jax.jit
    def train_step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        params = update_params(params, grads)
        return params, loss

    for step, batch in enumerate(dataloader):
        params, loss = train_step(params, batch)
        run.log({"train/loss": float(loss)}, step=step)
```

`@jax.jit`에서 반환되는 값은 디바이스 배열입니다. 스칼라 값은 `float(loss)` 또는 0차원 배열에서 `.item()`을 사용해 `run.log()`에 전달하세요. 원시 JAX 배열을 그대로 로깅하면 SDK 버전에 따라 직렬화가 실패하거나 예상치 못한 값이 기록될 수 있습니다.

<div id="log-validation-metrics">
  ## 검증 메트릭 로깅
</div>

Python에서 검증 메트릭을 집계한 뒤 에포크마다 한 번씩 로깅합니다:

```python theme={null}
for epoch in range(num_epochs):
    val_losses = []
    for val_batch in val_loader:
        val_loss = eval_step(params, val_batch)
        val_losses.append(float(val_loss))

    run.log({
        "epoch": epoch,
        "val/loss": sum(val_losses) / len(val_losses),
    })
```

<div id="save-checkpoints-as-artifacts">
  ## 체크포인트를 아티팩트로 저장
</div>

`orbax` 또는 `flax.serialization`을 사용해 JAX 또는 Flax 파라미터를 저장한 다음, 이를 아티팩트로 로깅하세요:

```python theme={null}
import orbax.checkpoint as ocp

checkpointer = ocp.StandardCheckpointer()
checkpointer.save("/tmp/checkpoint", params)
checkpointer.wait_until_finished()

artifact = wandb.Artifact("jax-model", type="model")
artifact.add_dir("/tmp/checkpoint")
run.log_artifact(artifact)
```

<div id="debug-nan-values">
  ## NaN 값 디버깅하기
</div>

JAX는 기본적으로 NaN 값이 발생해도 오류를 발생시키지 않습니다. `loss`와 함께 NaN 플래그를 로깅하세요:

```python theme={null}
import jax.numpy as jnp

run.log({
    "train/loss": float(loss),
    "train/loss_is_nan": bool(jnp.isnan(loss)),
})
```

개발 중에는 성능 저하를 감수하고 JAX의 디버그 NaN 검사를 활성화할 수 있습니다:

```python theme={null}
from jax import config
config.update("jax_debug_nans", True)
```

***

<Badge stroke shape="pill" color="orange" size="md">[Experiments](/ko/support/models/tags/experiments)</Badge><Badge stroke shape="pill" color="orange" size="md">[Runs](/ko/support/models/tags/runs)</Badge><Badge stroke shape="pill" color="orange" size="md">[메트릭](/ko/support/models/tags/metrics)</Badge>
