PyTorch Lightning provides a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training and 16-bit precision. W&B provides a lightweight wrapper for logging your ML experiments. Our integration is built directly into the PyTorch Lightning library, so you can always check out their documentation as well.
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)
We've created a few examples for you to see how the integration works:
Demo in Google Colab with hyperparameter optimization
Tutorial: Supercharge your Training with Pytorch Lightning + Weights & Biases
Semantic Segmentation with Lightning: optimize neural networks for self-driving cars
A step by step guide to tracking your Lightning model performance
Optional parameters:
name (str) – display name for the run.
save_dir (str) – path where data is saved (wandb dir by default).
offline (bool) – run offline (data can be streamed later to wandb servers).
id (str) – sets the version, mainly used to resume a previous run.
version (str) – same as id (legacy).
anonymous (bool) – enables or explicitly disables anonymous logging.
project (str) – the name of the project to which this run will belong.
log_model (bool) – save checkpoints in wandb dir to upload on W&B servers.
prefix (str) – string to put at the beginning of metric keys.
sync_step (bool) – sync Trainer step with wandb step (True by default).
**kwargs – additional arguments such as entity, group, or tags, used by wandb.init, can be passed as keyword arguments in this logger.
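As a sketch of how such keyword arguments could be assembled (the entity, group, and tag values below are made up), assuming pytorch_lightning and wandb are installed:

```python
# Hypothetical values; these kwargs are forwarded to wandb.init() by the logger.
init_kwargs = dict(
    entity="my-team",                    # W&B team or username
    group="ablation-study",              # group related runs together
    tags=["baseline", "resnet"],         # searchable labels for the run
)

# Uncomment to construct the logger with these extra arguments:
# from pytorch_lightning.loggers import WandbLogger
# wandb_logger = WandbLogger(project="my-project", **init_kwargs)
```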
Log model topology as well as optionally gradients and weights.
wandb_logger.watch(model, log='gradients', log_freq=100)
Parameters:
model (nn.Module) – model to be logged.
log (str) – can be "gradients" (default), "parameters", "all" or None.
log_freq (int) – step count between logging of gradients and parameters (100 by default).
Record hyperparameter configuration.
Note: this function is called automatically when using LightningModule.save_hyperparameters().
wandb_logger.log_hyperparams(params)
Parameters:
params (dict) – dictionary with hyperparameter names as keys and configuration values as values
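For illustration, params is a plain dictionary; the hyperparameter names and values below are made up:

```python
# Hypothetical hyperparameters; logging them records each key/value
# pair in the run's config.
params = {
    "learning_rate": 1e-3,
    "batch_size": 32,
    "optimizer": "adam",
}

# Assuming wandb_logger is a configured WandbLogger:
# wandb_logger.log_hyperparams(params)
```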
Record training metrics.
Note: this function is called automatically by LightningModule.log('metric', value)
wandb_logger.log_metrics(metrics, step=None)
Parameters:
metrics (dict) – dictionary with metric names as keys and measured quantities as values
step (int|None) – step number at which the metrics should be recorded
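As a hedged sketch (the metric names, values, and step below are made up), metrics is a dictionary mapping names to numbers:

```python
# Hypothetical metrics; step is the global training step at which
# these values should be recorded.
metrics = {
    "train/loss": 0.42,
    "val/accuracy": 0.91,
}

# Assuming wandb_logger is a configured WandbLogger:
# wandb_logger.log_metrics(metrics, step=100)
```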