Skip to main content
W&B에는 JAX 전용 인테그레이션이 없습니다. 하지만 JAX 디바이스 배열을 Python 스칼라로 변환한 다음, 다른 Python 워크플로와 마찬가지로 트레이닝 루프에서 wandb.log()를 사용할 수 있습니다. 로깅 전에 배열을 변환하면 값이 올바르게 직렬화됩니다. 실험 설정과 로깅 패턴은 실험 만들기객체와 미디어 로깅을 참조하세요. Flax 체크포인트에 대해서는 Artifacts를 참조하세요. 아래 섹션에서는 JAX 트레이닝 루프의 기본 로깅 패턴, 검증 메트릭을 집계하는 방법, JAX 또는 Flax 체크포인트를 W&B 아티팩트로 저장하는 방법, 그리고 트레이닝 중 NaN loss를 드러내는 방법을 설명합니다.

JAX 트레이닝 루프에서 메트릭 로깅하기

다음 예시에서는 JIT 컴파일된 JAX 트레이닝 루프를 실행하여 각 batch마다 모델 파라미터를 업데이트하고, step별 트레이닝 loss를 W&B run에 로깅합니다.
import jax
import wandb

with wandb.init(
    project="my-jax-project",
    config={"learning_rate": 1e-3, "batch_size": 64, "epochs": 50},
) as run:
    lr = run.config.learning_rate

    @jax.jit
    def train_step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        params = update_params(params, grads)
        return params, loss

    for step, batch in enumerate(dataloader):
        params, loss = train_step(params, batch)
        run.log({"train/loss": float(loss)}, step=step)
@jax.jit에서 반환되는 값은 디바이스 배열입니다. 스칼라 값은 float(loss) 또는 0차원 배열에서 .item()을 사용해 run.log()에 전달하세요. 원시 JAX 배열을 그대로 로깅하면 SDK 버전에 따라 직렬화가 실패하거나 예상치 못한 값이 기록될 수 있습니다.

검증 메트릭 로깅

Python에서 검증 메트릭을 집계한 뒤 에포크마다 한 번씩 로깅합니다:
for epoch in range(num_epochs):
    val_losses = []
    for val_batch in val_loader:
        val_loss = eval_step(params, val_batch)
        val_losses.append(float(val_loss))

    run.log({
        "epoch": epoch,
        "val/loss": sum(val_losses) / len(val_losses),
    })

체크포인트를 아티팩트로 저장

orbax 또는 flax.serialization을 사용해 JAX 또는 Flax 파라미터를 저장한 다음, 이를 아티팩트로 로깅하세요:
import orbax.checkpoint as ocp

checkpointer = ocp.StandardCheckpointer()
checkpointer.save("/tmp/checkpoint", params)
checkpointer.wait_until_finished()

artifact = wandb.Artifact("jax-model", type="model")
artifact.add_dir("/tmp/checkpoint")
run.log_artifact(artifact)

NaN 값 디버깅하기

JAX는 기본적으로 NaN 값이 발생해도 오류를 발생시키지 않습니다. loss와 함께 NaN 플래그를 로깅하세요:
import jax.numpy as jnp

run.log({
    "train/loss": float(loss),
    "train/loss_is_nan": bool(jnp.isnan(loss)),
})
개발 중에는 성능 저하를 감수하고 JAX의 디버그 NaN 검사를 활성화할 수 있습니다:
from jax import config
config.update("jax_debug_nans", True)

Experiments Runs 메트릭