PyTorch Lightning
Build scalable, structured, high-performance PyTorch models with Lightning and log them with W&B.
PyTorch Lightning provides a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training and 16-bit precision. W&B provides a lightweight wrapper for logging your ML experiments. But you don't need to combine the two yourself: we're incorporated directly into the PyTorch Lightning library, so you can always check out their documentation for reference information on the API.

⚡ Get going lightning-fast with just two lines.

```python
from pytorch_lightning.loggers import WandbLogger  # new line 1
from pytorch_lightning import Trainer

wandb_logger = WandbLogger()  # new line 2
trainer = Trainer(logger=wandb_logger)
```

Check out interactive examples!

Colab + Video Tutorial: Run GPU-accelerated PyTorch Lightning plus W&B logging without installing anything using this Colab, and follow along with a video tutorial!
Kaggle Kernel: See how PyTorch Lightning and W&B can accelerate your model development and help you climb the leaderboard with this Kaggle Kernel.
Blog Posts: Read more on specific topics in blog posts made with Weights & Biases Reports.

Frequently Asked Questions

How does W&B integrate with Lightning?

The core integration is based on the Lightning loggers API, which lets you write much of your logging code in a framework-agnostic way. Loggers are passed to the Lightning Trainer and are triggered based on that API's rich hook-and-callback system. This keeps your research code well-separated from engineering and logging code.
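For example, here is a minimal sketch of configuring the logger and handing it to the Trainer; the project name and keyword arguments are illustrative, not required:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# keyword arguments accepted by wandb.init (project, entity, tags, ...) are
# passed straight through by WandbLogger
wandb_logger = WandbLogger(project="my-project")

# the Trainer drives the logger through Lightning's hook-and-callback system,
# so anything reported with self.log ends up in W&B
trainer = Trainer(logger=wandb_logger, max_epochs=5)
```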

What does the integration log without any additional code?

We'll save your model checkpoints to W&B, where you can view them with the Netron model viewer or download them for use in future runs. We'll also capture system metrics (like GPU usage and network I/O), environment information (hardware and OS), code state (including the git commit and diff patch, notebook contents, and session history), and anything printed to standard output.
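If you want control over how checkpoints are uploaded, the WandbLogger's log_model argument covers the common cases. A small sketch, with illustrative values:

```python
from pytorch_lightning.loggers import WandbLogger

# log_model="all" uploads a checkpoint every time the ModelCheckpoint callback saves one;
# log_model=True uploads only the final checkpoint when training ends
wandb_logger = WandbLogger(project="my-project", log_model="all")
```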

How do I log scalar metrics, like accuracy, mIoU, and SSIM?

Because the WandbLogger is part of the broader Lightning loggers API, logging of scalar values to W&B can be done in a framework-agnostic way: just call self.log.
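For instance, a minimal sketch of logging a scalar this way; the metric name and loss computation are placeholders:

```python
import pytorch_lightning as pl
import torch.nn.functional as F


class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # self.log works with any Lightning logger, so this same line sends
        # the value to W&B whenever a WandbLogger is attached to the Trainer
        self.log("val/loss", loss)
```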
You can calculate these scalar metrics with Lightning's Metric API. In addition to providing robust and tested methods for calculating quantities like accuracy and signal-to-noise ratio, Metrics do lots of work under the hood, like maintaining state for efficient epoch-wise calculation and abstracting away device management. The code snippet below shows best practices for defining LightningModules so that metric calculation and logging works regardless of device or parallelism strategies used. That way you can get the most out of PyTorch Lightning's advanced features for high-performance code without compromising on logging.
```python
import pytorch_lightning as pl
import torch.nn.functional as F


class MyLitModule(pl.LightningModule):

    def __init__(self, *args, **kwargs):
        super().__init__()
        # initialize the rest of the module (layers, etc.) here
        acc = pl.metrics.Accuracy()
        # use .clone so that each metric can maintain its own state
        self.train_acc = acc.clone()
        # assign all metrics as attributes of the module so they are detected as children
        self.valid_acc = acc.clone()

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        preds = self(inputs)
        loss = F.cross_entropy(preds, targets)  # compute your training loss here
        # return a dictionary
        return {"loss": loss, "preds": preds, "targets": targets}

    def training_step_end(self, outs):
        # update and log accuracy on each step_end, for compatibility with data-parallel
        self.train_acc(outs["preds"], outs["targets"])
        self.log("train/acc_step", self.train_acc)

    def training_epoch_end(self, outs):
        # additionally, log mean accuracy at the end of the epoch
        self.log("train/acc_epoch", self.train_acc.compute())
```
Lightning's Metrics are being transferred into a stand-alone library, torchmetrics, and will be unavailable in the base package starting with version 1.5. Read more here.
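If you are on a release where Metrics already live in torchmetrics, the same pattern applies. A hedged sketch; recent torchmetrics versions require the task argument, while older ones accept Accuracy() with no arguments:

```python
import pytorch_lightning as pl
import torchmetrics


class MyLitModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # instantiate one metric per phase so that each maintains its own state
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
```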

How do I log media objects?

Weights & Biases provides a wide variety of data types for rich media logging (read the guide or check the reference docs for more).
Unlike scalars, media objects are logged differently by each framework. To keep this more involved logging code separate from the core logic of your research code, use Lightning Callbacks.
Inside your Callback, you can call either wandb.log (as when using wandb with other libraries) or trainer.logger.experiment.log; the logger's experiment attribute is the underlying wandb run, so in either case you can do anything you could do with wandb.log.
When manually calling wandb.log or trainer.logger.experiment.log, make sure to include the key/value pair "global_step": trainer.global_step. That way, you can line up the information you're currently logging with information logged via other methods.
Image Logging
Log the input and output images of an autoencoder or other image-to-image transformation network. Input-output pairs are combined into single images.
Outputs (top) for given inputs (bottom) of an auto-encoder trained on MNIST. ReLU troubles!
```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torch
import wandb


class WandbImageCallback(pl.Callback):
    """Logs the input and output images of a module.

    Images are stacked into a mosaic, with output on the top
    and input on the bottom."""

    def __init__(self, val_samples, max_samples=32):
        super().__init__()
        self.val_imgs, _ = val_samples
        self.val_imgs = self.val_imgs[:max_samples]

    def on_validation_end(self, trainer, pl_module):
        val_imgs = self.val_imgs.to(device=pl_module.device)

        outs = pl_module(val_imgs)

        mosaics = torch.cat([outs, val_imgs], dim=-2)
        caption = "Top: Output, Bottom: Input"
        trainer.logger.experiment.log({
            "val/examples": [wandb.Image(mosaic, caption=caption)
                             for mosaic in mosaics],
            "global_step": trainer.global_step
        })

...

trainer = pl.Trainer(
    ...
    callbacks=[WandbImageCallback(val_samples)]
)
```
Image Classification Logging
Log the input images together with the predicted and true labels for a single-class classification network.
Images and labels for a classifier trained on MNIST. Look for the mistake!
```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torch
import wandb


class WandbImagePredCallback(pl.Callback):
    """Logs the input images and output predictions of a module.

    Predictions and labels are logged as class indices."""

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.val_imgs, self.val_labels = val_samples
        self.val_imgs = self.val_imgs[:num_samples]
        self.val_labels = self.val_labels[:num_samples]

    def on_validation_epoch_end(self, trainer, pl_module):
        val_imgs = self.val_imgs.to(device=pl_module.device)

        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, 1)

        trainer.logger.experiment.log({
            "val/examples": [
                wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                for x, pred, y in zip(val_imgs, preds, self.val_labels)
            ],
            "global_step": trainer.global_step
        })

...

trainer = pl.Trainer(
    ...
    callbacks=[WandbImagePredCallback(val_samples)]
)
```

How do I use multiple GPUs with Lightning and W&B?

PyTorch Lightning supports multi-GPU training through its DDP interface. However, Lightning's design requires us to be careful about how we set up the processes on those GPUs.
Lightning assumes that each GPU (or rank) in your training loop is instantiated in exactly the same way, with the same initial conditions. However, only the rank 0 process gets access to the wandb.run object; for non-zero rank processes, wandb.run is None. This can cause your non-zero rank processes to fail, and it can also put you in a deadlock, because the rank 0 process will wait for the non-zero rank processes to join when they have already crashed.
For this reason, we have to be careful about how we set up our training code. The recommended way is to make your code independent of the wandb.run object.
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger


class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("val/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def main():
    # Setting all the random seeds to the same value.
    # This is important in a distributed training setting.
    # Each rank will get its own set of initial weights.
    # If they don't match up, the gradients will not match either,
    # leading to training that may not converge.
    pl.seed_everything(1)

    # train_dataset and val_dataset are assumed to be defined elsewhere
    train_loader = DataLoader(train_dataset, batch_size=64,
                              shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64,
                            shuffle=False, num_workers=4)

    model = MNISTClassifier()
    pl_logger = WandbLogger(project="<project_name>")
    callbacks = [
        ModelCheckpoint(
            dirpath="checkpoints",
            every_n_train_steps=100,
        ),
    ]
    trainer = pl.Trainer(
        max_epochs=3,
        gpus=2,
        logger=pl_logger,
        strategy="ddp",
        callbacks=callbacks,
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()
```

But what if I really need to use wandb.run in my training setup?

Essentially, you have to expand the scope of the value you need so that it is available to every process; in other words, you need to make sure the initial conditions are the same on all ranks.
```python
# only the rank 0 process (launched before LOCAL_RANK is set) can see wandb.run
if os.environ.get("LOCAL_RANK", None) is None:
    os.environ["WANDB_DIR"] = wandb.run.dir
```
Then you can use os.environ["WANDB_DIR"] to set up the model checkpoint directory, so that the run directory chosen by rank 0 (wandb.run.dir) is available to the non-zero rank processes as well.
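For example, here is a sketch of pointing a ModelCheckpoint callback at that directory; the fallback value "checkpoints" is just an illustrative default:

```python
import os

from pytorch_lightning.callbacks import ModelCheckpoint

# every rank reads the directory from the environment rather than from wandb.run,
# so the callback is configured identically on all processes
checkpoint_callback = ModelCheckpoint(
    dirpath=os.environ.get("WANDB_DIR", "checkpoints"),
    every_n_train_steps=100,
)
```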