PyTorch Ignite
Ignite provides a Weights & Biases handler to log metrics, model and optimizer parameters, and gradients during training and validation. It can also be used to log model checkpoints to the Weights & Biases cloud. The WandBLogger class is also a wrapper for the wandb module, which means you can call any wandb function through it. See the examples below for how to save model parameters and gradients.
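Because the logger delegates attribute access to the wandb module, you can call any wandb function directly on the logger object. A minimal sketch, where the project name and logged values are illustrative placeholders and not part of the script that follows:

```python
from ignite.contrib.handlers.wandb_logger import WandBLogger

# Creating the logger starts a wandb run (keyword arguments are forwarded to wandb.init)
wandb_logger = WandBLogger(project="pytorch-ignite-integration")

# Attributes the logger does not define are forwarded to the wandb module,
# so these behave like wandb.log(...) and wandb.save(...)
wandb_logger.log({"custom_metric": 0.5})
wandb_logger.save("checkpoint.pt")
```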

The basic PyTorch setup

```python
from argparse import ArgumentParser
import wandb
import torch
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.datasets import MNIST

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

from tqdm import tqdm


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)


def get_data_loaders(train_batch_size, val_batch_size):
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
                              batch_size=train_batch_size, shuffle=True)

    val_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
                            batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader
```
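Before wiring up any logging, you can sanity-check the model and loaders on a single batch. This snippet is only illustrative and is not part of the original script; the batch sizes are arbitrary:

```python
# Quick sanity check: one forward pass on a single training batch
train_loader, val_loader = get_data_loaders(train_batch_size=64, val_batch_size=1000)
model = Net()
images, labels = next(iter(train_loader))
log_probs = model(images)
print(log_probs.shape)  # expected: torch.Size([64, 10])
```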
Using WandBLogger in Ignite is a two-step modular process: first, create a WandBLogger object; then attach it to any trainer or evaluator to automatically log the metrics. In the code below we do the following tasks sequentially: 1) create a WandBLogger object, and 2) attach it to output handlers to:
    Log training loss - attached to the trainer
    Log validation loss - attached to the evaluator
    Log optional parameters, such as the learning rate
    Watch the model
```python
from ignite.contrib.handlers.wandb_logger import *


def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = 'cpu'

    if torch.cuda.is_available():
        device = 'cuda'

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(
        initial=0, leave=False, total=len(train_loader),
        desc=desc.format(0)
    )

    # WandBLogger object creation
    wandb_logger = WandBLogger(
        project="pytorch-ignite-integration",
        name="cnn-mnist",
        config={"max_epochs": epochs, "batch_size": train_batch_size},
        tags=["pytorch-ignite", "mnist"]
    )

    # Log the training loss at every iteration
    wandb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED,
        tag="training",
        output_transform=lambda loss: {"loss": loss}
    )

    # Log evaluator metrics after every epoch, keyed to the trainer's iteration count
    wandb_logger.attach_output_handler(
        evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="training",
        metric_names=["nll", "accuracy"],
        global_step_transform=lambda *_: trainer.state.iteration,
    )

    # Log the optimizer's learning rate at the start of every iteration
    wandb_logger.attach_opt_params_handler(
        trainer,
        event_name=Events.ITERATION_STARTED,
        optimizer=optimizer,
        param_name='lr'  # optional
    )

    # Log model gradients and parameters
    wandb_logger.watch(model)
```
Optionally, we can also use Ignite events to log the metrics directly to the terminal. These handlers are defined inside the run() function, continuing the snippet above:
```python
    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        pbar.desc = desc.format(engine.state.output)
        pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll)
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        pbar.n = pbar.last_print_n = 0

    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--val_batch_size', type=int, default=1000,
                        help='input batch size for validation (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10,
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5,
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--log_interval', type=int, default=10,
                        help='how many batches to wait before logging training status')

    args = parser.parse_args()
    run(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum, args.log_interval)
```
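One optional addition: the script never explicitly closes the logger, so the underlying wandb run is only finished when the process exits. If you want the run finalized as soon as training ends, you can close the logger at the end of run(). A minimal sketch, assuming the close() method available on WandBLogger in recent Ignite releases:

```python
    # ... at the end of run(), after trainer.run() and pbar.close():
    wandb_logger.close()  # finishes the underlying wandb run
```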
Running the above code produces interactive visualizations of the logged metrics and parameters in the W&B dashboard.
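If you prefer to launch training from a notebook rather than the command line, you can call run() directly; the argument values below simply mirror the script's defaults:

```python
# Equivalent to running the script with its default command-line arguments
run(train_batch_size=64, val_batch_size=1000, epochs=10,
    lr=0.01, momentum=0.5, log_interval=10)
```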
Refer to the Ignite docs for more detailed documentation.