wandb.log()를 사용할 수 있습니다. 로깅 전에 배열을 변환하면 값이 올바르게 직렬화됩니다.
실험 설정과 로깅 패턴은 실험 만들기 및 객체와 미디어 로깅을 참조하세요. Flax 체크포인트에 대해서는 Artifacts를 참조하세요.
아래 섹션에서는 JAX 트레이닝 루프의 기본 로깅 패턴, 검증 메트릭을 집계하는 방법, JAX 또는 Flax 체크포인트를 W&B 아티팩트로 저장하는 방법, 그리고 트레이닝 중 NaN loss를 드러내는 방법을 설명합니다.
JAX 트레이닝 루프에서 메트릭 로깅하기
@jax.jit에서 반환되는 값은 디바이스 배열입니다. 스칼라 값은 float(loss) 또는 0차원 배열에서 .item()을 사용해 run.log()에 전달하세요. 원시 JAX 배열을 그대로 로깅하면 SDK 버전에 따라 직렬화가 실패하거나 예상치 못한 값이 기록될 수 있습니다.
검증 메트릭 로깅
체크포인트를 아티팩트로 저장
orbax 또는 flax.serialization을 사용해 JAX 또는 Flax 파라미터를 저장한 다음, 이를 아티팩트로 로깅하세요:
NaN 값 디버깅하기
loss와 함께 NaN 플래그를 로깅하세요:
Experiments Runs 메트릭