def train(
train_dataset,
val_dataset,
model,
optimizer,
train_acc_metric,
val_acc_metric,
epochs=10,
log_step=200,
val_log_step=50,
):
# wandb.init으로 run을 시작하고 프로젝트 및 설정 정보를 전달합니다.
run = wandb.init(
project="my-tf-integration",
config={
"epochs": epochs,
"log_step": log_step,
"val_log_step": val_log_step,
"architecture": "MLP",
"dataset": "MNIST",
},
)
for epoch in range(epochs):
print("\nStart of epoch %d" % (epoch,))
train_loss = []
val_loss = []
# 데이터셋의 배치에 대해 반복합니다.
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
loss_value = train_step(
x_batch_train,
y_batch_train,
model,
optimizer,
loss_fn,
train_acc_metric,
)
train_loss.append(float(loss_value))
# 각 에포크가 끝날 때 검증 루프를 실행합니다.
for step, (x_batch_val, y_batch_val) in enumerate(val_dataset):
val_loss_value = test_step(
x_batch_val, y_batch_val, model, loss_fn, val_acc_metric
)
val_loss.append(float(val_loss_value))
# 에포크 종료 시 메트릭을 표시합니다.
train_acc = train_acc_metric.result()
print("Training acc over epoch: %.4f" % (float(train_acc),))
val_acc = val_acc_metric.result()
print("Validation acc: %.4f" % (float(val_acc),))
# 에포크 종료 시 메트릭 상태를 리셋합니다.
train_acc_metric.reset_state()
val_acc_metric.reset_state()
# run.log()를 사용하여 메트릭 로그를 기록합니다.
run.log(
{
"epochs": epoch,
"loss": np.mean(train_loss),
"acc": float(train_acc),
"val_loss": np.mean(val_loss),
"val_acc": float(val_acc),
}
)
run.finish()