# W&B: このモデルのトレーニングを追跡する新しい run を初期化します
wandb.init(project="table-quickstart")
# W&B: config を使用してハイパーパラメーターをログ
cfg = wandb.config
cfg.update({"epochs" : EPOCHS, "batch_size": BATCH_SIZE, "lr" : LEARNING_RATE,
"l1_size" : L1_SIZE, "l2_size": L2_SIZE,
"conv_kernel" : CONV_KERNEL_SIZE,
"img_count" : min(10000, NUM_IMAGES_PER_BATCH*NUM_BATCHES_TO_LOG)})
# モデル、損失関数、オプティマイザーを定義
model = ConvNet(NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# テスト画像のバッチの予測をログするための便利な関数
def log_test_predictions(images, labels, outputs, predicted, test_table, log_counter):
# すべてのクラスの信頼スコアを取得
scores = F.softmax(outputs.data, dim=1)
log_scores = scores.cpu().numpy()
log_images = images.cpu().numpy()
log_labels = labels.cpu().numpy()
log_preds = predicted.cpu().numpy()
# 画像の順序に基づいて ID を追加
_id = 0
for i, l, p, s in zip(log_images, log_labels, log_preds, log_scores):
# データテーブルに必要な情報を追加:
# ID、画像ピクセル、モデルの推測、真のラベル、すべてのクラスのスコア
img_id = str(_id) + "_" + str(log_counter)
test_table.add_data(img_id, wandb.Image(i), p, l, *s)
_id += 1
if _id == NUM_IMAGES_PER_BATCH:
break
# モデルをトレーニングする
total_step = len(train_loader)
for epoch in range(EPOCHS):
# トレーニングステップ
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
# forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
# W&B: トレーニングステップでの損失をログし、UIでライブ視覚化
wandb.log({"loss" : loss})
if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, EPOCHS, i+1, total_step, loss.item()))
# W&B: 各テストステップでの予測を保存するためのテーブルを作成
columns=["id", "image", "guess", "truth"]
for digit in range(10):
columns.append("score_" + str(digit))
test_table = wandb.Table(columns=columns)
# モデルをテスト
model.eval()
log_counter = 0
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
if log_counter < NUM_BATCHES_TO_LOG:
log_test_predictions(images, labels, outputs, predicted, test_table, log_counter)
log_counter += 1
total += labels.size(0)
correct += (predicted == labels).sum().item()
acc = 100 * correct / total
# W&B: トレーニングエポックの精度をログして、UIで可視化
wandb.log({"epoch" : epoch, "acc" : acc})
print('Test Accuracy of the model on the 10000 test images: {} %'.format(acc))
# W&B: 予測テーブルを wandb にログ
wandb.log({"test_predictions" : test_table})
# W&B: run を完了としてマークする(マルチセルノートブックに便利)
wandb.finish()