
Tutorials

Get started using Weights & Biases with interactive tutorials.

The following tutorials take you through the fundamentals of W&B for machine learning experiment tracking, model evaluation, hyperparameter tuning, model and dataset versioning, and more.

  1. Track experiments
  2. Visualize predictions
  3. Tune hyperparameters
  4. Track models and datasets

See the following tutorials for step-by-step information on how to use popular ML frameworks and libraries with W&B:

Other resources

Visit the W&B AI Academy to learn how to train, fine-tune and use LLMs in your applications. Implement MLOps and LLMOps solutions. Tackle real-world ML challenges with W&B courses.

1 - Track experiments

Use W&B for machine learning experiment tracking, model checkpointing, collaboration with your team and more.

In this notebook, you will create and track a machine learning experiment using a simple PyTorch model. By the end of the notebook, you will have an interactive project dashboard that you can share and customize with other members of your team. View an example dashboard here.

Prerequisites

Install the W&B Python SDK and log in:

!pip install wandb -qU
import wandb
import random
import math

# Use wandb-core, the new W&B backend. Recent SDK versions use it
# by default, so this call may be unnecessary there.
wandb.require("core")

# Log in to your W&B account
wandb.login()

Simulate and track a machine learning experiment with W&B

Create, track, and visualize a machine learning experiment. To do this:

  1. Initialize a W&B run and pass in the hyperparameters you want to track.
  2. Within your training loop, log metrics such as the accuracy and loss.
import random
import math

# Launch 5 simulated experiments
total_runs = 5
for run in range(total_runs):
    # 1. Start a new run to track this script
    wandb.init(
        # Set the project where this run will be logged
        project="basic-intro",
        # Pass a run name (otherwise it's randomly assigned, like sunshine-lollypop-10)
        name=f"experiment_{run}",
        # Track hyperparameters and run metadata
        config={
            "learning_rate": 0.02,
            "architecture": "CNN",
            "dataset": "CIFAR-100",
            "epochs": 10,
        },
    )

    # This simple block simulates a training loop logging metrics
    epochs = 10
    offset = random.random() / 5
    for epoch in range(2, epochs):
        acc = 1 - 2 ** -epoch - random.random() / epoch - offset
        loss = 2 ** -epoch + random.random() / epoch + offset

        # 2. Log metrics from your script to W&B
        wandb.log({"acc": acc, "loss": loss})

    # Mark the run as finished
    wandb.finish()

View how your machine learning model performed in your W&B project. Copy and paste the URL printed by the previous cell. The URL takes you to a W&B project with a dashboard of graphs that show how the logged accuracy and loss changed during each run.

The following image shows what a dashboard can look like:

Now that you know how to integrate W&B into a pseudo machine learning training loop, let’s track a machine learning experiment using a basic PyTorch neural network. The following code also uploads model checkpoints to W&B that you can then share with other teams in your organization.

Track a machine learning experiment using PyTorch

The following code cell defines and trains a simple MNIST classifier. During training, W&B prints out URLs. Click the project page link to see your results stream in live to a W&B project.

W&B runs automatically log metrics, system information, hyperparameters, and terminal output. You will also see an interactive table of model inputs and outputs.

Set up PyTorch Dataloader

The following cell defines some useful functions that we need to train our machine learning model. The functions themselves are not unique to W&B, so we won’t cover them in detail here. See the PyTorch documentation for more information on how to write the forward and backward passes of a training loop, how to use PyTorch DataLoaders to load data for training, and how to define PyTorch models using the torch.nn.Sequential class.

import torch, torchvision
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as T

MNIST.mirrors = [
    mirror for mirror in MNIST.mirrors if "http://yann.lecun.com/" not in mirror
]

device = "cuda:0" if torch.cuda.is_available() else "cpu"


def get_dataloader(is_train, batch_size, slice=5):
    "Get a training dataloader"
    full_dataset = MNIST(
        root=".", train=is_train, transform=T.ToTensor(), download=True
    )
    sub_dataset = torch.utils.data.Subset(
        full_dataset, indices=range(0, len(full_dataset), slice)
    )
    loader = torch.utils.data.DataLoader(
        dataset=sub_dataset,
        batch_size=batch_size,
        shuffle=True if is_train else False,
        pin_memory=True,
        num_workers=2,
    )
    return loader


def get_model(dropout):
    "A simple model"
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(256, 10),
    ).to(device)
    return model


def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    "Compute performance of the model on the validation dataset and log a wandb.Table"
    model.eval()
    val_loss = 0.0
    with torch.inference_mode():
        correct = 0
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)

            # Forward pass ➡
            outputs = model(images)
            val_loss += loss_func(outputs, labels) * labels.size(0)

            # Compute accuracy and accumulate
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always same batch_idx.
            if i == batch_idx and log_images:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    return val_loss / len(valid_dl.dataset), correct / len(valid_dl.dataset)
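validate_model scales each batch's mean loss by the batch size before summing, then divides by the number of examples, so a smaller final batch does not skew the result. The stdlib sketch below illustrates that arithmetic with plain floats in place of tensors; weighted_mean_loss and naive_mean_loss are hypothetical helpers, not part of PyTorch or W&B.

```python
# Hypothetical helpers for illustration; not part of PyTorch or W&B.


def weighted_mean_loss(batch_means, batch_sizes):
    """Average per-example loss from per-batch mean losses.

    Each batch mean is turned back into a sum by multiplying by its size,
    mirroring `loss_func(outputs, labels) * labels.size(0)` in validate_model.
    """
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


def naive_mean_loss(batch_means):
    """Unweighted average of batch means; over-weights a small final batch."""
    return sum(batch_means) / len(batch_means)


# Two full batches of 4 examples and a final batch of 1 example.
means = [0.5, 0.5, 2.0]
sizes = [4, 4, 1]
print(weighted_mean_loss(means, sizes))  # about 0.667, i.e. (2 + 2 + 2) / 9
print(naive_mean_loss(means))            # 1.0
```

The single hard example in the last batch counts once in the weighted version but as a full third of the naive average.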

Create a table to compare the predicted values with the true values

The following cell is unique to W&B, so let’s go over it.

In the cell, we define a function called log_image_table. Though technically optional, this function creates a W&B Table object. We use the table object to show what the model predicted for each image.

More specifically, each row consists of the image fed to the model, the predicted value, the actual value (label), and the model’s score for each of the ten classes.

def log_image_table(images, predicted, labels, probs):
    "Log a wandb.Table with (img, pred, target, scores)"
    # Create a wandb Table to log images, labels and predictions to
    table = wandb.Table(
        columns=["image", "pred", "target"] + [f"score_{i}" for i in range(10)]
    )
    for img, pred, targ, prob in zip(
        images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")
    ):
        table.add_data(wandb.Image(img[0].numpy() * 255), pred, targ, *prob.numpy())
    wandb.log({"predictions_table": table}, commit=False)
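To see the row layout that log_image_table produces without a W&B run, the sketch below builds the same column list and rows with plain Python lists. It computes scores with a softmax, as `outputs.softmax(dim=1)` does in validate_model; build_rows is a hypothetical helper, and a string stands in for the wandb.Image.

```python
import math

NUM_CLASSES = 10
# Same column order as the wandb.Table in log_image_table.
COLUMNS = ["image", "pred", "target"] + [f"score_{i}" for i in range(NUM_CLASSES)]


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def build_rows(images, logits_batch, labels):
    """Hypothetical helper: one row per example, matching COLUMNS."""
    rows = []
    for img, logits, target in zip(images, logits_batch, labels):
        probs = softmax(logits)
        pred = max(range(NUM_CLASSES), key=lambda k: probs[k])  # argmax
        rows.append([img, pred, target, *probs])
    return rows


rows = build_rows(["img_0"], [[0.0] * 9 + [3.0]], [9])
print(len(COLUMNS), len(rows[0]))  # 13 13
```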

Train your model and upload checkpoints

The following code trains and saves model checkpoints to your project. Use model checkpoints like you normally would to assess how the model performed during training.

W&B also makes it easy to share your saved models and model checkpoints with other members of your team or organization. To learn how to share your model and model checkpoints with members outside of your team, see W&B Registry.

# Launch 3 experiments, trying different dropout rates
for _ in range(3):
    # initialise a wandb run
    wandb.init(
        project="pytorch-intro",
        config={
            "epochs": 5,
            "batch_size": 128,
            "lr": 1e-3,
            "dropout": random.uniform(0.01, 0.80),
        },
    )

    # Copy your config
    config = wandb.config

    # Get the data
    train_dl = get_dataloader(is_train=True, batch_size=config.batch_size)
    valid_dl = get_dataloader(is_train=False, batch_size=2 * config.batch_size)
    n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

    # A simple MLP model
    model = get_model(config.dropout)

    # Make the loss and optimizer
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # Training
    example_ct = 0
    step_ct = 0
    for epoch in range(config.epochs):
        model.train()
        for step, (images, labels) in enumerate(train_dl):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            train_loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            example_ct += len(images)
            metrics = {
                "train/train_loss": train_loss,
                "train/epoch": (step + 1 + (n_steps_per_epoch * epoch))
                / n_steps_per_epoch,
                "train/example_ct": example_ct,
            }

            if step + 1 < n_steps_per_epoch:
                # Log train metrics to wandb
                wandb.log(metrics)

            step_ct += 1

        val_loss, accuracy = validate_model(
            model, valid_dl, loss_func, log_images=(epoch == (config.epochs - 1))
        )

        # Log train and validation metrics to wandb
        val_metrics = {"val/val_loss": val_loss, "val/val_accuracy": accuracy}
        wandb.log({**metrics, **val_metrics})

        # Save the model checkpoint to wandb
        torch.save(model, "my_model.pt")
        wandb.log_model(
            "./my_model.pt",
            "my_mnist_model",
            aliases=[f"epoch-{epoch+1}_dropout-{round(wandb.config.dropout, 4)}"],
        )

        print(
            f"Epoch: {epoch+1}, Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}"
        )

    # If you had a test set, this is how you could log it as a Summary metric
    wandb.summary["test_accuracy"] = 0.8

    # Close your wandb run
    wandb.finish()
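The train/epoch value logged in the loop above is a fractional epoch counter, so the x-axis advances smoothly within an epoch. The sketch below reproduces that formula and the n_steps_per_epoch computation with plain integers; steps_per_epoch and fractional_epoch are hypothetical names used only for illustration.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches per epoch; the last batch may be partial."""
    return math.ceil(num_examples / batch_size)


def fractional_epoch(step, epoch, n_steps_per_epoch):
    """Same formula as train/epoch above: (step + 1 + n * epoch) / n."""
    return (step + 1 + n_steps_per_epoch * epoch) / n_steps_per_epoch


# 60,000 MNIST training images with slice=5 leaves 12,000 examples.
n = steps_per_epoch(12_000, 128)
print(n)                              # 94
print(fractional_epoch(46, 0, n))     # 0.5, halfway through the first epoch
print(fractional_epoch(n - 1, 1, n))  # 2.0, end of the second epoch
```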

You have now trained your first model using W&B. Click one of the links above to see your metrics and your saved model checkpoints in the Artifacts tab of the W&B App UI.

(Optional) Set up a W&B Alert

Create W&B Alerts to send notifications to Slack or email from your Python code.

There are two steps to follow the first time you want to send a Slack or email alert triggered from your code:

  1. Turn on Alerts in your W&B User Settings
  2. Add wandb.alert() to your code. For example:
wandb.alert(title="Low accuracy", text="Accuracy is below the acceptable threshold")

The following cell shows a minimal example of how to use wandb.alert:

# Start a wandb run
wandb.init(project="pytorch-intro")

# Simulating a model training loop
acc_threshold = 0.3
for training_step in range(1000):

    # Generate a random number for accuracy
    accuracy = round(random.random() + random.random(), 3)
    print(f"Accuracy is: {accuracy}, {acc_threshold}")

    # Log accuracy to wandb
    wandb.log({"Accuracy": accuracy})

    # If the accuracy is below the threshold, fire a W&B Alert and stop the run
    if accuracy <= acc_threshold:
        # Send the wandb Alert
        wandb.alert(
            title="Low Accuracy",
            text=f"Accuracy {accuracy} at step {training_step} is below the acceptable threshold, {acc_threshold}",
        )
        print("Alert triggered")
        break

# Mark the run as finished (useful in Jupyter notebooks)
wandb.finish()
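To test the alert condition without a W&B run, you can separate the threshold check from the call that sends the alert. In this stdlib sketch, first_alert is a hypothetical helper and send_alert is a hypothetical callback standing in for wandb.alert; the accuracies are a fixed list rather than random numbers.

```python
def first_alert(accuracies, threshold, send_alert):
    """Hypothetical helper: alert on the first accuracy at or below threshold.

    Returns the step that triggered the alert, or None if none did.
    """
    for step, accuracy in enumerate(accuracies):
        if accuracy <= threshold:
            send_alert(
                title="Low Accuracy",
                text=f"Accuracy {accuracy} at step {step} is below the acceptable threshold, {threshold}",
            )
            return step  # stop, like the `break` in the cell above
    return None


sent = []
step = first_alert([0.9, 0.7, 0.25, 0.1], 0.3, lambda **kw: sent.append(kw))
print(step, sent[0]["title"])  # 2 Low Accuracy
```

Passing the sender in as a callback keeps the logic testable offline; in the notebook you would pass wandb.alert instead.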

You can find the full docs for W&B Alerts here.

Next steps

In the next tutorial, you will learn how to optimize hyperparameters using W&B Sweeps: Hyperparameters sweeps using PyTorch

2 - Visualize predictions with tables

This tutorial covers how to track, visualize, and compare model predictions over the course of training, using PyTorch on MNIST data.

You will learn how to:

  1. Log metrics, images, text, etc. to a wandb.Table() during model training or evaluation
  2. View, sort, filter, group, join, interactively query, and explore these tables
  3. Compare model predictions or results: dynamically across specific images, hyperparameters/model versions, or time steps.

Examples

Compare predicted scores for specific images

Live example: compare predictions after 1 vs 5 epochs of training →

1 epoch vs 5 epochs of training

The histograms compare per-class scores between the two models. The top green bar in each histogram represents model “CNN-2, 1 epoch” (id 0), which only trained for 1 epoch. The bottom purple bar represents model “CNN-2, 5 epochs” (id 1), which trained for 5 epochs. The images are filtered to cases where the models disagree. For example, in the first row, the “4” gets high scores across all the possible digits after 1 epoch, but after 5 epochs it scores highest on the correct label and very low on the rest.

Focus on top errors over time

Live example →

See incorrect predictions (filter to rows where “guess” != “truth”) on the full test data. Note that there are 229 wrong guesses after 1 training epoch, but only 98 after 5 epochs.

side by side, 1 vs 5 epochs of training

Compare model performance and find patterns

See full detail in a live example →

Filter out correct answers, then group by the guess to see examples of misclassified images and the underlying distribution of true labels—for two models side-by-side. A model variant with 2X the layer sizes and learning rate is on the left, and the baseline is on the right. Note that the baseline makes slightly more mistakes for each guessed class.

grouped errors for baseline vs double variant
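The table operations in these examples (filter out correct answers, then group by the guess) are ordinary query steps. Here is a stdlib sketch of the same filter and group-by on rows shaped like the id, guess, and truth columns logged later in this tutorial; the rows are toy values for illustration, not results from the live examples.

```python
from collections import Counter, defaultdict

# Toy rows of (id, guess, truth); illustrative values only.
rows = [
    ("0_0", 4, 4),
    ("1_0", 9, 4),
    ("2_0", 9, 7),
    ("3_0", 3, 5),
    ("4_0", 9, 4),
]

# Filter: keep only rows where guess != truth.
errors = [r for r in rows if r[1] != r[2]]

# Group by guess and count the true labels behind each wrong guess.
by_guess = defaultdict(Counter)
for _id, guess, truth in errors:
    by_guess[guess][truth] += 1

print(len(errors))        # 4
print(dict(by_guess[9]))  # {4: 2, 7: 1}
```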

Sign up or log in

Sign up or log in to W&B to see and interact with your experiments in the browser.

In this example we’re using Google Colab as a convenient hosted environment, but you can run your own training scripts from anywhere and visualize metrics with W&B’s experiment tracking tool.

!pip install wandb -qqq

Log in to your account:


import wandb
wandb.login()

WANDB_PROJECT = "mnist-viz"

0. Setup

Install dependencies, download MNIST, and create train and test datasets using PyTorch.

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T 
import torch.nn.functional as F


device = "cuda:0" if torch.cuda.is_available() else "cpu"

# create train and test dataloaders
def get_dataloader(is_train, batch_size):
    "Get a train or test dataloader over the full MNIST split"
    ds = torchvision.datasets.MNIST(root=".", train=is_train, transform=T.ToTensor(), download=True)
    loader = torch.utils.data.DataLoader(dataset=ds,
                                         batch_size=batch_size,
                                         shuffle=is_train,
                                         pin_memory=True, num_workers=2)
    return loader

1. Define the model and training schedule

  • Set the number of epochs to run, where each epoch consists of a training step and a validation (test) step. Optionally configure the amount of data to log per test step. Here the number of batches and number of images per batch to visualize are set low to simplify the demo.
  • Define a simple convolutional neural net (following pytorch-tutorial code).
  • Load in train and test sets using PyTorch
# Number of epochs to run
# Each epoch includes a training step and a test step, so this sets
# the number of tables of test predictions to log
EPOCHS = 1

# Number of batches to log from the test data for each test step
# (default set low to simplify demo)
NUM_BATCHES_TO_LOG = 10 #79

# Number of images to log per test batch
# (default set low to simplify demo)
NUM_IMAGES_PER_BATCH = 32 #128

# training configuration and hyperparameters
NUM_CLASSES = 10
BATCH_SIZE = 32
LEARNING_RATE = 0.001
L1_SIZE = 32
L2_SIZE = 64
# changing this may require changing the shape of adjacent layers
CONV_KERNEL_SIZE = 5

# define a two-layer convolutional neural network
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, L1_SIZE, CONV_KERNEL_SIZE, stride=1, padding=2),
            nn.BatchNorm2d(L1_SIZE),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(L1_SIZE, L2_SIZE, CONV_KERNEL_SIZE, stride=1, padding=2),
            nn.BatchNorm2d(L2_SIZE),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # two conv + pool blocks reduce a 28x28 image to 7x7
        self.fc = nn.Linear(7*7*L2_SIZE, num_classes)

    def forward(self, x):
        # uncomment to see the shape of a given layer:
        #print("x: ", x.size())
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        # return raw logits; nn.CrossEntropyLoss applies log-softmax itself
        out = self.fc(out)
        return out

train_loader = get_dataloader(is_train=True, batch_size=BATCH_SIZE)
test_loader = get_dataloader(is_train=False, batch_size=2*BATCH_SIZE)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
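The 7*7*L2_SIZE input size of the final linear layer follows from the standard output-size formula for convolution and pooling layers with dilation 1: out = floor((in + 2 * padding - kernel) / stride) + 1. The sketch below applies that formula to ConvNet's two conv/pool blocks; out_size and convnet_flat_features are hypothetical helpers. It also shows why the comment on CONV_KERNEL_SIZE warns about adjacent layers: changing the kernel without changing the padding changes the flattened size.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d or MaxPool2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def convnet_flat_features(image_size=28, conv_kernel=5, padding=2, l2_size=64):
    """Hypothetical helper: trace one side of the image through both blocks."""
    size = image_size
    for _ in range(2):  # layer1 and layer2
        size = out_size(size, conv_kernel, stride=1, padding=padding)
        size = out_size(size, kernel=2, stride=2)  # MaxPool2d(2, 2)
    return size * size * l2_size


print(convnet_flat_features())               # 3136, i.e. 7 * 7 * 64
print(convnet_flat_features(conv_kernel=3))  # 4096, i.e. 8 * 8 * 64
```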

2. Run training and log test predictions

For every epoch, run a training step and a test step. For each test step, create a wandb.Table() in which to store test predictions. These can be visualized, dynamically queried, and compared side by side in your browser.

# ✨ W&B: Initialize a new run to track this model's training
wandb.init(project="table-quickstart")

# ✨ W&B: Log hyperparameters using config
cfg = wandb.config
cfg.update({"epochs" : EPOCHS, "batch_size": BATCH_SIZE, "lr" : LEARNING_RATE,
            "l1_size" : L1_SIZE, "l2_size": L2_SIZE,
            "conv_kernel" : CONV_KERNEL_SIZE,
            "img_count" : min(10000, NUM_IMAGES_PER_BATCH*NUM_BATCHES_TO_LOG)})

# define model, loss, and optimizer
model = ConvNet(NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# convenience function to log predictions for a batch of test images
def log_test_predictions(images, labels, outputs, predicted, test_table, log_counter):
  # obtain confidence scores for all classes
  scores = F.softmax(outputs.data, dim=1)
  log_scores = scores.cpu().numpy()
  log_images = images.cpu().numpy()
  log_labels = labels.cpu().numpy()
  log_preds = predicted.cpu().numpy()
  # adding ids based on the order of the images
  _id = 0
  for i, l, p, s in zip(log_images, log_labels, log_preds, log_scores):
    # add required info to data table:
    # id, image pixels, model's guess, true label, scores for all classes
    img_id = str(_id) + "_" + str(log_counter)
    test_table.add_data(img_id, wandb.Image(i), p, l, *s)
    _id += 1
    if _id == NUM_IMAGES_PER_BATCH:
      break

# train the model
total_step = len(train_loader)
for epoch in range(EPOCHS):
    # training step (switch back from the eval mode set by the previous test step)
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        # forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
  
        # ✨ W&B: Log loss over training steps, visualized in the UI live
        wandb.log({"loss" : loss})
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                .format(epoch+1, EPOCHS, i+1, total_step, loss.item()))
            

    # ✨ W&B: Create a Table to store predictions for each test step
    columns=["id", "image", "guess", "truth"]
    for digit in range(10):
      columns.append("score_" + str(digit))
    test_table = wandb.Table(columns=columns)

    # test the model
    model.eval()
    log_counter = 0
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            if log_counter < NUM_BATCHES_TO_LOG:
              log_test_predictions(images, labels, outputs, predicted, test_table, log_counter)
              log_counter += 1
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        # ✨ W&B: Log accuracy across training epochs, to visualize in the UI
        wandb.log({"epoch" : epoch, "acc" : acc})
        print('Test Accuracy of the model on the 10000 test images: {} %'.format(acc))

    # ✨ W&B: Log predictions table to wandb
    wandb.log({"test_predictions" : test_table})

# ✨ W&B: Mark the run as complete (useful for multi-cell notebook)
wandb.finish()

What’s next?

In the next tutorial, you will learn how to optimize hyperparameters using W&B Sweeps:

👉 Optimize Hyperparameters

3 - Tune hyperparameters with sweeps

Finding a machine learning model that meets your desired metric (such as model accuracy) is normally a repetitive task that can take multiple iterations. To make matters worse, it might be unclear which hyperparameter combinations to use for a given training run.

Use W&B Sweeps to create an organized and efficient way to automatically search through combinations of hyperparameter values such as the learning rate, batch size, number of hidden layers, optimizer type and more to find values that optimize your model based on your desired metric.

In this tutorial you will create a hyperparameter search with W&B PyTorch integration. Follow along with a video tutorial.

Sweeps: An Overview

Running a hyperparameter sweep with Weights & Biases takes three steps:

  1. Define the sweep: create a dictionary or a YAML file that specifies the parameters to search through, the search strategy, the optimization metric, and so on.

  2. Initialize the sweep: with one line of code, initialize the sweep and pass in the dictionary of sweep configurations: sweep_id = wandb.sweep(sweep_config)

  3. Run the sweep agent: also with one line of code, call wandb.agent() and pass it the sweep_id along with a function that defines your model architecture and trains it: wandb.agent(sweep_id, function=train)

Before you get started

Install W&B and import the W&B Python SDK into your notebook:

  1. Install with !pip install:
!pip install wandb -Uq
  2. Import W&B:
import wandb
  3. Log in to W&B and provide your API key when prompted:
wandb.login()

Step 1: Define a sweep

A W&B Sweep combines a strategy for trying numerous hyperparameter values with the code that evaluates them. Before you start a sweep, you must define your sweep strategy with a sweep configuration.

Pick a search method

First, specify a hyperparameter search method within your configuration dictionary. There are three hyperparameter search strategies to choose from: grid, random, and Bayesian search.

For this tutorial, you will use a random search. Within your notebook, create a dictionary and specify random for the method key.

sweep_config = {
    'method': 'random'
    }

Specify a metric that you want to optimize for. You do not need to specify the metric and goal for sweeps that use the random search method. However, it is good practice to keep track of your sweep goals so that you can refer to them later.

metric = {
    'name': 'loss',
    'goal': 'minimize'   
    }

sweep_config['metric'] = metric
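The goal tells W&B which direction counts as better for the metric. As a sketch of what 'goal': 'minimize' means when you compare finished runs, the helper below picks the best run from a list of summary dictionaries. best_run is hypothetical and not part of the W&B SDK, and the loss values are toy numbers.

```python
def best_run(runs, metric):
    """Hypothetical helper: return the run whose metric is best for the goal.

    `metric` has the same shape as the sweep config entry,
    for example {"name": "loss", "goal": "minimize"}.
    """
    pick = min if metric["goal"] == "minimize" else max
    return pick(runs, key=lambda run: run[metric["name"]])


runs = [
    {"id": "a", "loss": 0.42},
    {"id": "b", "loss": 0.31},
    {"id": "c", "loss": 0.57},
]
print(best_run(runs, {"name": "loss", "goal": "minimize"})["id"])  # b
print(best_run(runs, {"name": "loss", "goal": "maximize"})["id"])  # c
```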

Specify hyperparameters to search through

Now that you have a search method specified in your sweep configuration, specify the hyperparameters you want to search over.

To do this, add one or more hyperparameter names under the parameters key, and for each one, list the values to search through under the values key.

The values you search through for a given hyperparameter depend on the type of hyperparameter you are investigating.

For example, if you tune the optimizer, you must specify a finite list of optimizer names, such as Adam and stochastic gradient descent (SGD).

parameters_dict = {
    'optimizer': {
        'values': ['adam', 'sgd']
        },
    'fc_layer_size': {
        'values': [128, 256, 512]
        },
    'dropout': {
          'values': [0.3, 0.4, 0.5]
        },
    }

sweep_config['parameters'] = parameters_dict

Sometimes you want to track a hyperparameter, but not vary its value. In this case, add the hyperparameter to your sweep configuration and specify the exact value that you want to use. For example, in the following code cell, epochs is set to 1.

parameters_dict.update({
    'epochs': {
        'value': 1}
    })

For a random search, all the values of a parameter are equally likely to be chosen on a given run.
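As a rough model of what random search does with these entries, the sketch below draws one configuration from a parameters dictionary: a values list is sampled uniformly and a fixed value is passed through unchanged. sample_config is a hypothetical helper; the W&B sweep controller does the real sampling and also supports the distributions described next, which this sketch does not handle.

```python
import random


def sample_config(parameters, rng=random):
    """Hypothetical helper: draw one config from 'values' lists and fixed 'value' entries."""
    config = {}
    for name, spec in parameters.items():
        if "value" in spec:
            config[name] = spec["value"]  # tracked but never varied
        elif "values" in spec:
            config[name] = rng.choice(spec["values"])  # each option equally likely
        else:
            raise ValueError(f"{name}: distributions are not handled in this sketch")
    return config


parameters = {
    "optimizer": {"values": ["adam", "sgd"]},
    "fc_layer_size": {"values": [128, 256, 512]},
    "dropout": {"values": [0.3, 0.4, 0.5]},
    "epochs": {"value": 1},
}
print(sample_config(parameters, random.Random(0)))
```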

Alternatively, you can specify a named distribution, plus its parameters, like the mean mu and standard deviation sigma of a normal distribution.

parameters_dict.update({
    'learning_rate': {
        # a flat distribution between 0 and 0.1
        'distribution': 'uniform',
        'min': 0,
        'max': 0.1
      },
    'batch_size': {
        # integers between 32 and 256
        # with evenly-distributed logarithms 
        'distribution': 'q_log_uniform_values',
        'q': 8,
        'min': 32,
        'max': 256,
      }
    })
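To get a feel for these two distributions, here is a stdlib sketch of how values could be drawn, assuming the definitions in the W&B sweep configuration reference: uniform draws from [min, max], log_uniform_values draws X so that log(X) is uniform between log(min) and log(max), and the q_ variant returns round(X / q) * q. The sampling functions are hypothetical stand-ins for what the sweep controller does.

```python
import math
import random


def sample_uniform(lo, hi, rng=random):
    """Hypothetical sampler: flat distribution between lo and hi."""
    return rng.uniform(lo, hi)


def sample_q_log_uniform_values(lo, hi, q, rng=random):
    """Hypothetical sampler: log-uniform between lo and hi, rounded to a multiple of q."""
    x = math.exp(rng.uniform(math.log(lo), math.log(hi)))
    return round(x / q) * q


rng = random.Random(42)
batch_sizes = [sample_q_log_uniform_values(32, 256, 8, rng) for _ in range(1000)]
learning_rates = [sample_uniform(0, 0.1, rng) for _ in range(1000)]

# On a log scale, each doubling (32-64, 64-128, 128-256) gets roughly
# the same share of draws, so small batch sizes are not under-sampled.
print(sum(b <= 64 for b in batch_sizes), sum(b >= 128 for b in batch_sizes))
```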

When we’re finished, sweep_config is a nested dictionary that specifies exactly which parameters we’re interested in trying and the method we’re going to use to try them.

Let’s see what the sweep configuration looks like:

import pprint
pprint.pprint(sweep_config)

For a full list of configuration options, see Sweep configuration options.

Step 2: Initialize the Sweep

Once you’ve defined the search strategy, it’s time to set up the component that carries it out.

W&B uses a Sweep Controller to manage sweeps on the cloud or locally across one or more machines. For this tutorial, you will use a sweep controller managed by W&B.

While sweep controllers manage sweeps, the component that actually executes a sweep is known as a sweep agent.

Within your notebook, you can activate a sweep controller with the wandb.sweep method. Pass it the sweep configuration dictionary you defined earlier:

sweep_id = wandb.sweep(sweep_config, project="pytorch-sweeps-demo")

The wandb.sweep function returns a sweep_id that you will use at a later step to activate your sweep.

For more information on how to create W&B Sweeps in a terminal, see the W&B Sweep walkthrough.

Step 3: Define your machine learning code

Before you execute the sweep, define the training procedure that uses the hyperparameter values you want to try. The key to integrating W&B Sweeps into your training code is to ensure that, for each training experiment, your training logic can access the hyperparameter values you defined in your sweep configuration.

In the following code example, the train function reads the hyperparameter values from wandb.config and passes them to the helper functions build_dataset, build_network, and build_optimizer, whose outputs train_epoch then uses.

Run the following machine learning training code in your notebook. The functions define a basic fully connected neural network in PyTorch.

import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(config=None):
    # Initialize a new wandb run
    with wandb.init(config=config):
        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller
        config = wandb.config

        loader = build_dataset(config.batch_size)
        network = build_network(config.fc_layer_size, config.dropout)
        optimizer = build_optimizer(network, config.optimizer, config.learning_rate)

        for epoch in range(config.epochs):
            avg_loss = train_epoch(network, loader, optimizer)
            wandb.log({"loss": avg_loss, "epoch": epoch})           

Within the train function, you will notice the following W&B Python SDK methods:

  • wandb.init(): Initialize a new W&B run. Each run is a single execution of the training function.
  • wandb.config: Access the hyperparameter values for this run. When wandb.agent calls train, the sweep controller sets these values.
  • wandb.log(): Log the training loss for each epoch.

The following cell defines four functions: build_dataset, build_network, build_optimizer, and train_epoch. These functions are a standard part of a basic PyTorch pipeline, and their implementation is unaffected by the use of W&B.

def build_dataset(batch_size):
   
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])
    # download MNIST training dataset
    dataset = datasets.MNIST(".", train=True, download=True,
                             transform=transform)
    sub_dataset = torch.utils.data.Subset(
        dataset, indices=range(0, len(dataset), 5))
    loader = torch.utils.data.DataLoader(sub_dataset, batch_size=batch_size)

    return loader


def build_network(fc_layer_size, dropout):
    network = nn.Sequential(  # fully connected, single hidden layer
        nn.Flatten(),
        nn.Linear(784, fc_layer_size), nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(fc_layer_size, 10),
        nn.LogSoftmax(dim=1))

    return network.to(device)
        

def build_optimizer(network, optimizer, learning_rate):
    if optimizer == "sgd":
        optimizer = optim.SGD(network.parameters(),
                              lr=learning_rate, momentum=0.9)
    elif optimizer == "adam":
        optimizer = optim.Adam(network.parameters(),
                               lr=learning_rate)
    return optimizer


def train_epoch(network, loader, optimizer):
    cumu_loss = 0
    for _, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # ➡ Forward pass
        loss = F.nll_loss(network(data), target)
        cumu_loss += loss.item()

        # ⬅ Backward pass + weight update
        loss.backward()
        optimizer.step()

        wandb.log({"batch loss": loss.item()})

    return cumu_loss / len(loader)

For more details on instrumenting W&B with PyTorch, see this Colab.

Step 4: Activate sweep agents

Now that you have a sweep configuration and a training function that reads its hyperparameters from that configuration, you are ready to activate a sweep agent. Sweep agents are responsible for running an experiment with a set of hyperparameter values that you defined in your sweep configuration.

Create sweep agents with the wandb.agent method. Provide the following:

  1. The sweep the agent is a part of (sweep_id)
  2. The function the sweep is supposed to run. In this example, the sweep will use the train function.
  3. (Optional) How many configs to ask the sweep controller for (count)

The following cell activates a sweep agent that runs the training function (train) 5 times:

wandb.agent(sweep_id, train, count=5)
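Conceptually, an agent repeatedly asks the controller for a configuration and calls your function with it, count times. The stdlib sketch below imitates that loop locally; run_local_agent and toy_train are hypothetical stand-ins, and unlike wandb.agent there is no controller, no W&B run, and no parallelism.

```python
import random


def run_local_agent(parameters, train_fn, count, seed=0):
    """Hypothetical stand-in for an agent: call train_fn(config) `count` times."""
    rng = random.Random(seed)
    results = []
    for _ in range(count):
        # Random search over 'values' lists only (no distributions in this sketch).
        config = {name: rng.choice(spec["values"]) for name, spec in parameters.items()}
        results.append((config, train_fn(config)))
    return results


def toy_train(config):
    """Toy objective standing in for a real training run: returns a fake loss."""
    return config["dropout"] + (0.1 if config["optimizer"] == "sgd" else 0.0)


parameters = {
    "optimizer": {"values": ["adam", "sgd"]},
    "dropout": {"values": [0.3, 0.4, 0.5]},
}
results = run_local_agent(parameters, toy_train, count=5)
best_config, best_loss = min(results, key=lambda item: item[1])
print(len(results), best_config, round(best_loss, 2))
```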

Visualize Sweep Results

Parallel Coordinates Plot

This plot maps hyperparameter values to model metrics. It’s useful for honing in on combinations of hyperparameters that led to the best model performance.

Hyperparameter Importance Plot

The hyperparameter importance plot surfaces which hyperparameters were the best predictors of your metrics. We report feature importance (from a random forest model) and correlation (implicitly a linear model).
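The correlation column is a linear measure. As a sketch, the helper below computes the Pearson correlation between one hyperparameter and a metric across runs. pearson is a hypothetical helper, the run values are toy numbers, and the random-forest importance score is a separate calculation not reproduced here.

```python
import math


def pearson(xs, ys):
    """Hypothetical helper: Pearson correlation of two equal-length sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


# Toy sweep results: dropout value and final loss for five runs.
toy_dropout = [0.3, 0.3, 0.4, 0.5, 0.5]
toy_loss = [0.20, 0.22, 0.25, 0.31, 0.29]
# A positive value means higher dropout went with higher loss in these runs.
print(round(pearson(toy_dropout, toy_loss), 3))
```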

These visualizations can help you save both time and resources running expensive hyperparameter optimizations by honing in on the parameters (and value ranges) that are the most important, and thereby worthy of further exploration.

Learn more about W&B Sweeps

We created a simple training script and a few flavors of sweep configs for you to play with. We highly encourage you to give these a try.

That repo also has examples to help you try more advanced sweep features like Bayesian search, Hyperband, and Hyperopt.

4 - Track models and datasets

In this notebook, we’ll show you how to track your ML experiment pipelines using W&B Artifacts.

Follow along with a video tutorial.

About artifacts

An artifact, like a Greek amphora, is a produced object – the output of a process. In ML, the most important artifacts are datasets and models.

And, like the Cross of Coronado, these important artifacts belong in a museum. That is, they should be cataloged and organized so that you, your team, and the ML community at large can learn from them. After all, those who don’t track training are doomed to repeat it.

Using our Artifacts API, you can log Artifacts as outputs of W&B Runs or use Artifacts as input to Runs, as in this diagram, where a training run takes in a dataset and produces a model.

Since one run can use another run’s output as an input, Artifacts and Runs together form a directed graph (a bipartite DAG, with nodes for Artifacts and Runs and arrows that connect a Run to the Artifacts it consumes or produces).
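To make the graph structure concrete, the following sketch builds this tutorial's lineage as a plain dictionary, using the job_type and Artifact names from the code in this section, and orders it with the standard library's graphlib (Python 3.9+). It is a toy model for illustration, not how W&B stores lineage. Because every edge joins a Run to an Artifact, the graph is bipartite, and a topological order exists only because it has no cycles.

```python
from graphlib import TopologicalSorter

# Runs (keyed by job_type) and the Artifacts they produce or consume,
# using the names from this tutorial's pipeline
produces = {
    "load-data": ["mnist-raw"],
    "preprocess-data": ["mnist-preprocess"],
    "initialize": ["convnet"],
    "train": ["trained-model"],
}
consumes = {
    "preprocess-data": ["mnist-raw"],
    "train": ["mnist-preprocess", "convnet"],
    "report": ["mnist-preprocess", "trained-model"],
}

# Map each node to the set of nodes that must come before it
graph = {}
for run, outputs in produces.items():
    for artifact in outputs:
        graph.setdefault(artifact, set()).add(run)  # a Run comes before the Artifact it produces
for run, inputs in consumes.items():
    graph.setdefault(run, set()).update(inputs)  # an Artifact comes before the Run that consumes it

# Bipartite check: every edge joins a Run to an Artifact, never two nodes of the same kind
run_nodes = set(produces) | set(consumes)
is_bipartite = all((node in run_nodes) != (pred in run_nodes)
                   for node, preds in graph.items() for pred in preds)

# static_order() raises graphlib.CycleError if the graph contains a cycle
order = list(TopologicalSorter(graph).static_order())
print(is_bipartite)  # True
print(order)
```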

Use artifacts to track models and datasets

Install and Import

Artifacts are part of our Python library, starting with version 0.9.2.

Like most parts of the ML Python stack, it’s available via pip.

# Compatible with wandb version 0.9.2+
!pip install wandb -qqq
!apt install tree
import os
import wandb

Log a Dataset

First, let’s define some Artifacts.

This example is based on this PyTorch “Basic MNIST Example”, but could just as easily have been done in TensorFlow, in any other framework, or in pure Python.

We start with the Datasets:

  • a training set, for choosing the parameters,
  • a validation set, for choosing the hyperparameters,
  • a testing set, for evaluating the final model

The first cell below defines these three datasets.

import random 

import torch
import torchvision
from torch.utils.data import TensorDataset
from tqdm.auto import tqdm

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data parameters
num_classes = 10
input_shape = (1, 28, 28)

# drop slow mirror from list of MNIST mirrors
torchvision.datasets.MNIST.mirrors = [mirror for mirror in torchvision.datasets.MNIST.mirrors
                                      if not mirror.startswith("http://yann.lecun.com")]

def load(train_size=50_000):
    """
    # Load the data
    """

    # the data, split between train and test sets
    train = torchvision.datasets.MNIST("./", train=True, download=True)
    test = torchvision.datasets.MNIST("./", train=False, download=True)
    (x_train, y_train), (x_test, y_test) = (train.data, train.targets), (test.data, test.targets)

    # split off a validation set for hyperparameter tuning
    x_train, x_val = x_train[:train_size], x_train[train_size:]
    y_train, y_val = y_train[:train_size], y_train[train_size:]

    training_set = TensorDataset(x_train, y_train)
    validation_set = TensorDataset(x_val, y_val)
    test_set = TensorDataset(x_test, y_test)

    datasets = [training_set, validation_set, test_set]

    return datasets

This sets up a pattern we’ll see repeated in this example: the code to log the data as an Artifact is wrapped around the code for producing that data. In this case, the code for loading the data (load) is separated out from the code for loading and logging the data (load_and_log).

This is good practice.

In order to log these datasets as Artifacts, we just need to

  1. create a Run with wandb.init,
  2. create an Artifact for the dataset with wandb.Artifact, and
  3. save and log the associated files with new_file and run.log_artifact.

Check out the code cell below, then read the sections after it for more details.

def load_and_log():

    # 🚀 start a run, with a type to label it and a project it can call home
    with wandb.init(project="artifacts-example", job_type="load-data") as run:
        
        datasets = load()  # separate code for loading the datasets
        names = ["training", "validation", "test"]

        # 🏺 create our Artifact
        raw_data = wandb.Artifact(
            "mnist-raw", type="dataset",
            description="Raw MNIST dataset, split into train/val/test",
            metadata={"source": "torchvision.datasets.MNIST",
                      "sizes": [len(dataset) for dataset in datasets]})

        for name, data in zip(names, datasets):
            # 🐣 Store a new file in the artifact, and write something into its contents.
            with raw_data.new_file(name + ".pt", mode="wb") as file:
                x, y = data.tensors
                torch.save((x, y), file)

        # ✍️ Save the artifact to W&B.
        run.log_artifact(raw_data)

load_and_log()

wandb.init

When we make the Run that’s going to produce the Artifacts, we need to state which project it belongs to.

Depending on your workflow, a project might be as big as car-that-drives-itself or as small as iterative-architecture-experiment-117.

Rule of 👍: if you can, keep all of the Runs that share Artifacts inside a single project. This keeps things simple, but don’t worry – Artifacts are portable across projects.

To help keep track of all the different kinds of jobs you might run, it’s useful to provide a job_type when making Runs. This keeps the graph of your Artifacts nice and tidy.

Rule of 👍: the job_type should be descriptive and correspond to a single step of your pipeline. Here, we separate out loading data from preprocessing data.

wandb.Artifact

To log something as an Artifact, we have to first make an Artifact object.

Every Artifact has a name – that’s what the first argument sets.

Rule of 👍: the name should be descriptive, but easy to remember and type – we like to use names that are hyphen-separated and correspond to variable names in the code.

It also has a type. Just like job_types for Runs, this is used for organizing the graph of Runs and Artifacts.

Rule of 👍: the type should be simple: more like dataset or model than mnist-data-YYYYMMDD.

You can also attach a description and some metadata, as a dictionary. The metadata just needs to be serializable to JSON.

Rule of 👍: the metadata should be as descriptive as possible.
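Since the metadata has to be serializable to JSON, a quick way to check a dictionary before attaching it is to round-trip it through the standard json module. check_metadata is a hypothetical helper, not part of wandb, and it only tests plain json.dumps. The sizes shown match the 50,000/10,000/10,000 split produced by the load function above.

```python
import json

def check_metadata(metadata):
    """Return True if `metadata` can be serialized to JSON, False otherwise."""
    try:
        json.dumps(metadata)
    except (TypeError, ValueError):  # TypeError: unsupported type; ValueError: circular reference
        return False
    return True

good = {"source": "torchvision.datasets.MNIST", "sizes": [50000, 10000, 10000]}
bad = {"source": "torchvision.datasets.MNIST", "sizes": {50000, 10000}}  # a set is not JSON-serializable

print(check_metadata(good))  # True
print(check_metadata(bad))   # False
```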

artifact.new_file and run.log_artifact

Once we’ve made an Artifact object, we need to add files to it.

You read that right: files with an s. Artifacts are structured like directories, with files and sub-directories.

Rule of 👍: whenever it makes sense to do so, split the contents of an Artifact up into multiple files. This will help if it comes time to scale.

We use the new_file method to simultaneously write the file and attach it to the Artifact. Below, we’ll use the add_file method, which separates those two steps.
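The difference between the two methods is easiest to see in a toy model. ToyArtifact below is a hypothetical, in-memory stand-in that maps file names to bytes; it is not wandb.Artifact and skips everything about staging, hashing, and uploading. It only mimics the two call shapes: new_file hands you a writable file object and attaches what you wrote when the with block exits, while add_file copies a file that already exists on disk.

```python
import contextlib
import io
import os
import tempfile

class ToyArtifact:
    """Toy stand-in for an artifact: a mapping from file name to bytes (NOT wandb.Artifact)."""

    def __init__(self, name):
        self.name = name
        self.files = {}

    @contextlib.contextmanager
    def new_file(self, name, mode="wb"):
        # One step: the caller writes into a buffer that is attached when the block exits.
        # `mode` is accepted to mirror the call shape but ignored; the buffer is always binary.
        buffer = io.BytesIO()
        yield buffer
        self.files[name] = buffer.getvalue()

    def add_file(self, local_path, name=None):
        # Two steps: the file already exists on disk, so copy its contents in
        with open(local_path, "rb") as f:
            self.files[name or os.path.basename(local_path)] = f.read()

artifact = ToyArtifact("toy-example")

# new_file: write and attach at once, with no separate local copy
with artifact.new_file("training.pt", mode="wb") as file:
    file.write(b"fake tensor bytes")

# add_file: write to disk first, then attach
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "initialized_model.pth")
    with open(path, "wb") as f:
        f.write(b"fake weights")
    artifact.add_file(path)

print(sorted(artifact.files))  # ['initialized_model.pth', 'training.pt']
```

The add_file route leaves an extra copy on local disk, which is why the Rule of 👍 later in this section favors new_file when you can use it.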

Once we’ve added all of our files, we call run.log_artifact to upload the Artifact to wandb.ai.

You’ll notice some URLs appeared in the output, including one for the Run page. That’s where you can view the results of the Run, including any Artifacts that got logged.

We’ll see some examples that make better use of the other components of the Run page below.

Use a Logged Dataset Artifact

Artifacts in W&B, unlike artifacts in museums, are designed to be used, not just stored.

Let’s see what that looks like.

The cell below defines a pipeline step that takes in a raw dataset and uses it to produce a preprocessed dataset: normalized and shaped correctly.

Notice again that we split out the meat of the code, preprocess, from the code that interfaces with wandb.

def preprocess(dataset, normalize=True, expand_dims=True):
    """
    ## Prepare the data
    """
    x, y = dataset.tensors

    if normalize:
        # Scale images to the [0, 1] range
        x = x.type(torch.float32) / 255

    if expand_dims:
        # Make sure images have shape (1, 28, 28)
        x = torch.unsqueeze(x, 1)
    
    return TensorDataset(x, y)

Now for the code that instruments this preprocess step with wandb.Artifact logging.

Note that the example below both uses an Artifact, which is new, and logs it, which is the same as the last step. Artifacts are both the inputs and the outputs of Runs.

We use a new job_type, preprocess-data, to make it clear that this is a different kind of job from the previous one.

def preprocess_and_log(steps):

    with wandb.init(project="artifacts-example", job_type="preprocess-data") as run:

        processed_data = wandb.Artifact(
            "mnist-preprocess", type="dataset",
            description="Preprocessed MNIST dataset",
            metadata=steps)
         
        # ✔️ declare which artifact we'll be using
        raw_data_artifact = run.use_artifact('mnist-raw:latest')

        # 📥 if need be, download the artifact
        raw_dataset = raw_data_artifact.download()
        
        for split in ["training", "validation", "test"]:
            raw_split = read(raw_dataset, split)
            processed_dataset = preprocess(raw_split, **steps)

            with processed_data.new_file(split + ".pt", mode="wb") as file:
                x, y = processed_dataset.tensors
                torch.save((x, y), file)

        run.log_artifact(processed_data)


def read(data_dir, split):
    filename = split + ".pt"
    x, y = torch.load(os.path.join(data_dir, filename))

    return TensorDataset(x, y)

One thing to notice here is that the preprocessing steps are saved with the processed_data Artifact as metadata.

If you’re trying to make your experiments reproducible, capturing lots of metadata is a good idea.

Also, even though our dataset is a “large artifact”, the download step is done in much less than a second.

The run.use_artifact and artifact.download sections after the next cell explain these steps in more detail.

steps = {"normalize": True,
         "expand_dims": True}

preprocess_and_log(steps)

run.use_artifact

These steps are simpler. The consumer just needs to know the name of the Artifact, plus a bit more.

That “bit more” is the alias of the particular version of the Artifact you want.

By default, the last version to be uploaded is tagged latest. Otherwise, you can pick older versions with v0/v1, etc., or you can provide your own aliases, like best or jit-script. Just like Docker Hub tags, aliases are separated from names with :, so the Artifact we want is mnist-raw:latest.

Rule of 👍: Keep aliases short and sweet. Use aliases like latest or best when you want an Artifact that satisfies some property.
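Resolving a reference like mnist-raw:latest amounts to splitting on the colon and looking the alias up. The sketch below does that against a hypothetical in-memory registry. parse_artifact_ref, resolve, and the registry layout are illustrative assumptions, and treating a missing alias as latest is an assumption of this sketch rather than documented wandb behavior.

```python
def parse_artifact_ref(ref):
    """Split 'name:alias' into (name, alias); this sketch assumes a missing alias means 'latest'."""
    name, sep, alias = ref.partition(":")
    return name, alias if sep else "latest"

def resolve(registry, ref):
    """Map 'name:alias' to a concrete version string such as 'v2'."""
    name, alias = parse_artifact_ref(ref)
    entry = registry[name]
    if alias in entry["versions"]:  # explicit version such as v0 or v1
        return alias
    return entry["aliases"][alias]  # named alias such as latest or best

# Hypothetical registry: three uploads of mnist-raw, plus a custom "best" alias
registry = {
    "mnist-raw": {
        "versions": ["v0", "v1", "v2"],
        "aliases": {"latest": "v2", "best": "v1"},
    }
}

print(resolve(registry, "mnist-raw:latest"))  # v2
print(resolve(registry, "mnist-raw:v0"))      # v0
print(resolve(registry, "mnist-raw:best"))    # v1
```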

artifact.download

Now, you may be worrying about the download call. If we download another copy, won’t that double the burden on memory?

Don’t worry, friend. Before we actually download anything, we check to see if the right version is available locally. This uses the same technology that underlies torrenting and version control with git: hashing.

As Artifacts are created and logged, a folder called artifacts in the working directory will start to fill with sub-directories, one for each Artifact. Check out its contents with !tree artifacts:

!tree artifacts
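The hash check described above can be sketched in a few lines. In this hypothetical example, fetch skips writing a file when a local copy with the same content hash already exists. SHA-256 is an arbitrary choice for illustration; this is not W&B's cache implementation, and the exact hash and manifest format W&B uses are not shown here.

```python
import hashlib
import os
import tempfile

def file_digest(path):
    """SHA-256 hex digest of a file's contents, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            h.update(chunk)
    return h.hexdigest()

def fetch(remote_bytes, remote_digest, local_path):
    """Write `remote_bytes` to `local_path` unless an identical copy is already there.

    Returns True if a "download" happened, False if the local copy was reused.
    """
    if os.path.exists(local_path) and file_digest(local_path) == remote_digest:
        return False  # cache hit: contents match, skip the transfer
    with open(local_path, "wb") as f:
        f.write(remote_bytes)
    return True

payload = b"processed MNIST split"
digest = hashlib.sha256(payload).hexdigest()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "training.pt")
    print(fetch(payload, digest, path))  # True: first call writes the file
    print(fetch(payload, digest, path))  # False: second call finds a matching hash
```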

The Artifacts page

Now that we’ve logged and used an Artifact, let’s check out the Artifacts tab on the Run page.

Navigate to the Run page URL from the wandb output and select the “Artifacts” tab from the left sidebar (it’s the one with the database icon, which looks like three hockey pucks stacked on top of one another).

Click a row in either the Input Artifacts table or in the Output Artifacts table, then check out the tabs (Overview, Metadata) to see everything logged about the Artifact.

We particularly like the Graph View. By default, it shows a graph with the types of Artifacts and the job_types of Runs as the two kinds of nodes, with arrows to represent consumption and production.

Log a Model

That’s enough to see how the API for Artifacts works, but let’s follow this example through to the end of the pipeline so we can see how Artifacts can improve your ML workflow.

This first cell here builds a DNN model in PyTorch – a really simple ConvNet.

We’ll start by just initializing the model, not training it. That way, we can repeat the training while keeping everything else constant.

from math import floor

import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self, hidden_layer_sizes=[32, 64],
                  kernel_sizes=[3],
                  activation="ReLU",
                  pool_sizes=[2],
                  dropout=0.5,
                  num_classes=num_classes,
                  input_shape=input_shape):
      
        super(ConvNet, self).__init__()

        self.layer1 = nn.Sequential(
              nn.Conv2d(in_channels=input_shape[0], out_channels=hidden_layer_sizes[0], kernel_size=kernel_sizes[0]),
              getattr(nn, activation)(),
              nn.MaxPool2d(kernel_size=pool_sizes[0])
        )
        self.layer2 = nn.Sequential(
              nn.Conv2d(in_channels=hidden_layer_sizes[0], out_channels=hidden_layer_sizes[-1], kernel_size=kernel_sizes[-1]),
              getattr(nn, activation)(),
              nn.MaxPool2d(kernel_size=pool_sizes[-1])
        )
        self.layer3 = nn.Sequential(
              nn.Flatten(),
              nn.Dropout(dropout)
        )

        fc_input_dims = floor((input_shape[1] - kernel_sizes[0] + 1) / pool_sizes[0]) # layer 1 output size
        fc_input_dims = floor((fc_input_dims - kernel_sizes[-1] + 1) / pool_sizes[-1]) # layer 2 output size
        fc_input_dims = fc_input_dims*fc_input_dims*hidden_layer_sizes[-1] # layer 3 output size

        self.fc = nn.Linear(fc_input_dims, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.fc(x)
        return x
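The three fc_input_dims lines in __init__ work out the size of the flattened feature vector by hand. The helper functions below are hypothetical (not part of the tutorial) and repeat that arithmetic so you can check it for other settings: each unpadded, stride-1 convolution shrinks a side by kernel_size - 1, and each max pool (whose stride defaults to its kernel size in PyTorch) divides it by pool_size, rounding down. Like the ConvNet code, it assumes square inputs. With the defaults, a side goes 28 → 26 → 13 through layer1, then 13 → 11 → 5 through layer2, and 5 × 5 × 64 = 1600.

```python
from math import floor

def conv_pool_output_size(size, kernel_size, pool_size):
    """Spatial side length after an unpadded, stride-1 conv followed by max pooling."""
    return floor((size - kernel_size + 1) / pool_size)

def fc_input_dims(input_shape=(1, 28, 28), hidden_layer_sizes=(32, 64),
                  kernel_sizes=(3,), pool_sizes=(2,)):
    side = conv_pool_output_size(input_shape[1], kernel_sizes[0], pool_sizes[0])  # after layer1
    side = conv_pool_output_size(side, kernel_sizes[-1], pool_sizes[-1])          # after layer2
    return side * side * hidden_layer_sizes[-1]  # flattened: height * width * channels

print(fc_input_dims())  # 1600
```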

Here, we’re using W&B to track the run, and so using the wandb.config object to store all of the hyperparameters.

The dictionary version of that config object is a really useful piece of metadata, so make sure to include it.

def build_model_and_log(config):
    with wandb.init(project="artifacts-example", job_type="initialize", config=config) as run:
        config = wandb.config
        
        model = ConvNet(**config)

        model_artifact = wandb.Artifact(
            "convnet", type="model",
            description="Simple AlexNet style CNN",
            metadata=dict(config))

        torch.save(model.state_dict(), "initialized_model.pth")
        # ➕ another way to add a file to an Artifact
        model_artifact.add_file("initialized_model.pth")

        wandb.save("initialized_model.pth")

        run.log_artifact(model_artifact)

model_config = {"hidden_layer_sizes": [32, 64],
                "kernel_sizes": [3],
                "activation": "ReLU",
                "pool_sizes": [2],
                "dropout": 0.5,
                "num_classes": 10}

build_model_and_log(model_config)

artifact.add_file

Instead of simultaneously writing a new_file and adding it to the Artifact, as in the dataset logging examples, we can also write files in one step (here, torch.save) and then add them to the Artifact in another.

Rule of 👍: use new_file when you can, to prevent duplication.

Use a Logged Model Artifact

Just like we could call use_artifact on a dataset, we can call it on our initialized model (the convnet Artifact) to use it in another Run.

This time, let’s train the model.

For more details, check out our Colab on instrumenting W&B with PyTorch.

import torch.nn.functional as F

def train(model, train_loader, valid_loader, config):
    optimizer = getattr(torch.optim, config.optimizer)(model.parameters())
    model.train()
    example_ct = 0
    for epoch in range(config.epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

            example_ct += len(data)

            if batch_idx % config.batch_log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0%})]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    batch_idx / len(train_loader), loss.item()))
                
                train_log(loss, example_ct, epoch)

        # evaluate the model on the validation set at each epoch
        loss, accuracy = test(model, valid_loader)  
        test_log(loss, accuracy, example_ct, epoch)

    
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum')  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum()

    test_loss /= len(test_loader.dataset)

    accuracy = 100. * correct / len(test_loader.dataset)
    
    return test_loss, accuracy


def train_log(loss, example_ct, epoch):
    loss = float(loss)

    # where the magic happens
    wandb.log({"epoch": epoch, "train/loss": loss}, step=example_ct)
    print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")
    

def test_log(loss, accuracy, example_ct, epoch):
    loss = float(loss)
    accuracy = float(accuracy)

    # where the magic happens
    wandb.log({"epoch": epoch, "validation/loss": loss, "validation/accuracy": accuracy}, step=example_ct)
    print(f"Loss/accuracy after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}/{accuracy:.3f}")

We’ll run two separate Artifact-producing Runs this time.

Once the first finishes training the model, the second will consume the trained-model Artifact by evaluating its performance on the test set.

Also, we’ll pull out the 32 examples on which the network gets the most confused: those on which the cross-entropy loss is highest.

This is a good way to diagnose issues with your dataset and your model.
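Picking the hardest examples boils down to taking the indices of the k largest losses. The get_hardest_k_examples function below does it with torch.argsort; as a framework-free illustration, heapq.nlargest gives the same set of indices (ordered hardest first rather than last) without sorting the whole list. hardest_k and the loss values are hypothetical.

```python
import heapq

def hardest_k(losses, k):
    """Indices of the k largest losses, hardest first."""
    return heapq.nlargest(k, range(len(losses)), key=losses.__getitem__)

# Hypothetical per-example cross-entropy losses for a tiny test set
losses = [0.05, 2.30, 0.10, 1.70, 0.02, 3.10]
print(hardest_k(losses, 3))  # [5, 1, 3]
```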

def evaluate(model, test_loader):
    """
    ## Evaluate the trained model
    """

    loss, accuracy = test(model, test_loader)
    highest_losses, hardest_examples, true_labels, predictions = get_hardest_k_examples(model, test_loader.dataset)

    return loss, accuracy, highest_losses, hardest_examples, true_labels, predictions

def get_hardest_k_examples(model, testing_set, k=32):
    model.eval()

    loader = DataLoader(testing_set, 1, shuffle=False)

    # get the losses and predictions for each item in the dataset
    losses = None
    predictions = None
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)
            pred = output.argmax(dim=1, keepdim=True)
            
            if losses is None:
                losses = loss.view((1, 1))
                predictions = pred
            else:
                losses = torch.cat((losses, loss.view((1, 1))), 0)
                predictions = torch.cat((predictions, pred), 0)

    argsort_loss = torch.argsort(losses, dim=0)

    highest_k_losses = losses[argsort_loss[-k:]]
    hardest_k_examples = testing_set[argsort_loss[-k:]][0]
    true_labels = testing_set[argsort_loss[-k:]][1]
    predicted_labels = predictions[argsort_loss[-k:]]

    return highest_k_losses, hardest_k_examples, true_labels, predicted_labels

These logging functions don’t add any new Artifact features, so we won’t comment on them: we’re just using, downloading, and logging Artifacts.

from torch.utils.data import DataLoader

def train_and_log(config):

    with wandb.init(project="artifacts-example", job_type="train", config=config) as run:
        config = wandb.config

        data = run.use_artifact('mnist-preprocess:latest')
        data_dir = data.download()

        training_dataset =  read(data_dir, "training")
        validation_dataset = read(data_dir, "validation")

        train_loader = DataLoader(training_dataset, batch_size=config.batch_size)
        validation_loader = DataLoader(validation_dataset, batch_size=config.batch_size)
        
        model_artifact = run.use_artifact("convnet:latest")
        model_dir = model_artifact.download()
        model_path = os.path.join(model_dir, "initialized_model.pth")
        model_config = model_artifact.metadata
        config.update(model_config)

        model = ConvNet(**model_config)
        model.load_state_dict(torch.load(model_path))
        model = model.to(device)
 
        train(model, train_loader, validation_loader, config)

        model_artifact = wandb.Artifact(
            "trained-model", type="model",
            description="Trained NN model",
            metadata=dict(model_config))

        torch.save(model.state_dict(), "trained_model.pth")
        model_artifact.add_file("trained_model.pth")
        wandb.save("trained_model.pth")

        run.log_artifact(model_artifact)

    return model

    
def evaluate_and_log(config=None):
    
    with wandb.init(project="artifacts-example", job_type="report", config=config) as run:
        data = run.use_artifact('mnist-preprocess:latest')
        data_dir = data.download()
        testing_set = read(data_dir, "test")

        test_loader = torch.utils.data.DataLoader(testing_set, batch_size=128, shuffle=False)

        model_artifact = run.use_artifact("trained-model:latest")
        model_dir = model_artifact.download()
        model_path = os.path.join(model_dir, "trained_model.pth")
        model_config = model_artifact.metadata

        model = ConvNet(**model_config)
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        loss, accuracy, highest_losses, hardest_examples, true_labels, preds = evaluate(model, test_loader)

        run.summary.update({"loss": loss, "accuracy": accuracy})

        wandb.log({"high-loss-examples":
            [wandb.Image(hard_example, caption=str(int(pred)) + "," +  str(int(label)))
             for hard_example, pred, label in zip(hardest_examples, preds, true_labels)]})

train_config = {"batch_size": 128,
                "epochs": 5,
                "batch_log_interval": 25,
                "optimizer": "Adam"}

model = train_and_log(train_config)
evaluate_and_log()

5 - Programmatic Workspaces

Organize and visualize your machine learning experiments more effectively by programmatically creating, managing, and customizing workspaces. You can define configurations, set panel layouts, and organize sections with the wandb-workspaces W&B library. You can load and modify workspaces by URL, use expressions to filter and group runs, and customize the appearance of runs.

wandb-workspaces is a Python library for programmatically creating and customizing W&B Workspaces and Reports.

In this tutorial you will see how to use wandb-workspaces to create and customize workspaces by defining configurations, setting panel layouts, and organizing sections.

How to use this notebook

  • Run each cell one at a time.
  • Copy and paste the URL that is printed after you run a cell to view the changes made to the workspace.

1. Install and import dependencies

# Install dependencies
!pip install wandb wandb-workspaces rich
# Import dependencies
import os
import wandb
import wandb_workspaces.workspaces as ws
import wandb_workspaces.reports.v2 as wr # We use the Reports API for adding panels

# Improve output formatting
%load_ext rich

2. Create a new project and workspace

For this tutorial we will create a new project so that we can experiment with the wandb_workspaces API:

Note: You can load an existing workspace using its unique saved view URL. See the (Optional) Load an existing project and workspace section below to learn how to do this.

# Initialize Weights & Biases and Login
wandb.login()

# Function to create a new project and log sample data
def create_project_and_log_data():
    project = "workspace-api-example"  # Default project name

    # Initialize a run to log some sample data
    with wandb.init(project=project, name="sample_run") as run:
        for step in range(100):
            wandb.log({
                "Step": step,
                "val_loss": 1.0 / (step + 1),
                "val_accuracy": step / 100.0,
                "train_loss": 1.0 / (step + 2),
                "train_accuracy": step / 110.0,
                "f1_score": step / 100.0,
                "recall": step / 120.0,
            })
    return project

# Create a new project and log data
project = create_project_and_log_data()
entity = wandb.Api().default_entity

(Optional) Load an existing project and workspace

Instead of creating a new project, you can load one of your own existing projects and workspaces. To do this, find the unique workspace URL and pass it to ws.Workspace.from_url as a string. The URL has the form https://wandb.ai/[SOURCE-ENTITY]/[SOURCE-PROJECT]?nw=abc.

For example:

wandb.login()

# Load an existing workspace from its saved view URL
workspace = ws.Workspace.from_url("https://wandb.ai/[SOURCE-ENTITY]/[SOURCE-PROJECT]?nw=abc")

# Or create a new, empty workspace object in another entity and project
workspace = ws.Workspace(
    entity="NEW-ENTITY",
    project="NEW-PROJECT",
    name="NEW-SAVED-VIEW-NAME"
)
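If you are unsure which entity and project a saved view URL points to, you can pull it apart with the standard library. parse_workspace_url is a hypothetical helper, not part of wandb-workspaces, and it assumes the path has the form <entity>/<project> with the saved view ID in the nw query parameter.

```python
from urllib.parse import parse_qs, urlparse

def parse_workspace_url(url):
    """Split a workspace URL into (entity, project, saved view ID).

    Assumes the form https://wandb.ai/<entity>/<project>?nw=<view-id>.
    """
    parsed = urlparse(url)
    entity, project = parsed.path.strip("/").split("/")[:2]
    view_id = parse_qs(parsed.query).get("nw", [None])[0]  # None if there is no nw parameter
    return entity, project, view_id

print(parse_workspace_url("https://wandb.ai/my-team/workspace-api-example?nw=abc"))
# ('my-team', 'workspace-api-example', 'abc')
```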

3. Programmatic workspace examples

Below are examples for using programmatic workspace features:

# See all available settings for workspaces, sections, and panels.
all_settings_objects = [x for x in dir(ws) if isinstance(getattr(ws, x), type)]
all_settings_objects

Create a workspace with saved view

This example demonstrates how to create a new workspace and populate it with sections and panels. Workspaces can be edited like regular Python objects, providing flexibility and ease of use.

def sample_workspace_saved_example(entity: str, project: str) -> str:
    workspace: ws.Workspace = ws.Workspace(
        name="Example W&B Workspace",
        entity=entity,
        project=project,
        sections=[
            ws.Section(
                name="Validation Metrics",
                panels=[
                    wr.LinePlot(x="Step", y=["val_loss"]),
                    wr.BarPlot(metrics=["val_accuracy"]),
                    wr.ScalarChart(metric="f1_score", groupby_aggfunc="mean"),
                ],
                is_open=True,
            ),
        ],
    )
    workspace.save()
    print("Sample Workspace saved.")
    return workspace.url

workspace_url: str = sample_workspace_saved_example(entity, project)

Load a workspace from a URL

Duplicate and customize workspaces without affecting the original setup. To do this, load an existing workspace and save it as a new view:

def save_new_workspace_view_example(url: str) -> None:
    workspace: ws.Workspace = ws.Workspace.from_url(url)

    # Rename the loaded copy, then save it as a separate view so the original stays unchanged
    workspace.name = "Updated Workspace Name"
    workspace.save_as_new_view()

    print("Workspace saved as new view.")

save_new_workspace_view_example(workspace_url)

Note that your workspace is now named “Updated Workspace Name”.

Basic settings

The following code shows how to create a workspace, add sections with panels, and configure settings for the workspace, individual sections, and panels:

# Function to create and configure a workspace with custom settings
def custom_settings_example(entity: str, project: str) -> None:
    workspace: ws.Workspace = ws.Workspace(name="An example workspace", entity=entity, project=project)
    workspace.sections = [
        ws.Section(
            name="Validation",
            panels=[
                wr.LinePlot(x="Step", y=["val_loss"]),
                wr.LinePlot(x="Step", y=["val_accuracy"]),
                wr.ScalarChart(metric="f1_score", groupby_aggfunc="mean"),
                wr.ScalarChart(metric="recall", groupby_aggfunc="mean"),
            ],
            is_open=True,
        ),
        ws.Section(
            name="Training",
            panels=[
                wr.LinePlot(x="Step", y=["train_loss"]),
                wr.LinePlot(x="Step", y=["train_accuracy"]),
            ],
            is_open=False,
        ),
    ]

    workspace.settings = ws.WorkspaceSettings(
        x_axis="Step",
        x_min=0,
        x_max=75,
        smoothing_type="gaussian",
        smoothing_weight=20.0,
        ignore_outliers=False,
        remove_legends_from_panels=False,
        tooltip_number_of_runs="default",
        tooltip_color_run_names=True,
        max_runs=20,
        point_visualization_method="bucketing",
        auto_expand_panel_search_results=False,
    )

    section = workspace.sections[0]
    section.panel_settings = ws.SectionPanelSettings(
        x_min=25,
        x_max=50,
        smoothing_type="none",
    )

    panel = section.panels[0]
    panel.title = "Validation Loss Custom Title"
    panel.title_x = "Custom x-axis title"

    workspace.save()
    print("Workspace with custom settings saved.")

# Run the function to create and configure the workspace
custom_settings_example(entity, project)

Note that you are now viewing a different saved view called “An example workspace”.

Customize runs

The following code cells show you how to filter, change the color, group, and sort runs programmatically.

In each example, the general workflow is to specify the desired customization as an argument to the appropriate parameter in ws.RunsetSettings.

Filter runs

You can create filters with Python expressions and metrics that you log with wandb.log or that are logged automatically as part of the run, such as Created Timestamp. You can also reference filters by how they appear in the W&B App UI, such as the Name, Tags, or ID.

The following example shows how to filter runs based on the validation loss summary, validation accuracy summary, and the regex specified:

def advanced_filter_example(entity: str, project: str) -> None:
    # Get all runs in the project
    runs: list = wandb.Api().runs(f"{entity}/{project}")

    # Apply multiple filters: val_loss < 0.1, val_accuracy > 0.8, and run name matches regex pattern
    workspace: ws.Workspace = ws.Workspace(
        name="Advanced Filtered Workspace with Regex",
        entity=entity,
        project=project,
        sections=[
            ws.Section(
                name="Advanced Filtered Section",
                panels=[
                    wr.LinePlot(x="Step", y=["val_loss"]),
                    wr.LinePlot(x="Step", y=["val_accuracy"]),
                ],
                is_open=True,
            ),
        ],
        runset_settings=ws.RunsetSettings(
            filters=[
                (ws.Summary("val_loss") < 0.1),  # Filter runs by the 'val_loss' summary
                (ws.Summary("val_accuracy") > 0.8),  # Filter runs by the 'val_accuracy' summary
                (ws.Metric("ID").isin([run.id for run in runs])),  # Keep only runs returned by the API call above
            ],
            regex_query=True,
        )
    )

    # Add regex search to match run names starting with 's'
    workspace.runset_settings.query = "^s"
    workspace.runset_settings.regex_query = True

    workspace.save()
    print("Workspace with advanced filters and regex search saved.")

advanced_filter_example(entity, project)

Note that passing in a list of filter expressions applies the boolean “AND” logic.
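To see what AND logic means for the filters above, here is a plain-Python sketch over hypothetical run summaries: a run is kept only when every predicate is true. The run records and apply_filters are illustrative and are not how wandb-workspaces evaluates filters; the regex name query is left out.

```python
# Hypothetical run summaries, shaped like the metrics logged in this tutorial
runs = [
    {"name": "sample_run", "val_loss": 0.01, "val_accuracy": 0.99},
    {"name": "slow_run", "val_loss": 0.25, "val_accuracy": 0.85},
    {"name": "best_run", "val_loss": 0.05, "val_accuracy": 0.95},
]

# One predicate per filter expression
filters = [
    lambda r: r["val_loss"] < 0.1,
    lambda r: r["val_accuracy"] > 0.8,
]

def apply_filters(runs, filters):
    """Keep a run only if every predicate is true (logical AND)."""
    return [r for r in runs if all(f(r) for f in filters)]

kept = apply_filters(runs, filters)
print([r["name"] for r in kept])  # ['sample_run', 'best_run']
```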

Change the colors of runs

This example demonstrates how to change the colors of the runs in a workspace:

def run_color_example(entity: str, project: str) -> None:
    # Get all runs in the project
    runs: list = wandb.Api().runs(f"{entity}/{project}")

    # Dynamically assign colors to the runs
    run_colors: list = ['purple', 'orange', 'teal', 'magenta']
    run_settings: dict = {}
    for i, run in enumerate(runs):
        run_settings[run.id] = ws.RunSettings(color=run_colors[i % len(run_colors)])

    workspace: ws.Workspace = ws.Workspace(
        name="Run Colors Workspace",
        entity=entity,
        project=project,
        sections=[
            ws.Section(
                name="Run Colors Section",
                panels=[
                    wr.LinePlot(x="Step", y=["val_loss"]),
                    wr.LinePlot(x="Step", y=["val_accuracy"]),
                ],
                is_open=True,
            ),
        ],
        runset_settings=ws.RunsetSettings(
            run_settings=run_settings
        )
    )

    workspace.save()
    print("Workspace with run colors saved.")

run_color_example(entity, project)

Group runs

This example demonstrates how to group runs by specific metrics.

def grouping_example(entity: str, project: str) -> None:
    workspace: ws.Workspace = ws.Workspace(
        name="Grouped Runs Workspace",
        entity=entity,
        project=project,
        sections=[
            ws.Section(
                name="Grouped Runs",
                panels=[
                    wr.LinePlot(x="Step", y=["val_loss"]),
                    wr.LinePlot(x="Step", y=["val_accuracy"]),
                ],
                is_open=True,
            ),
        ],
        runset_settings=ws.RunsetSettings(
            groupby=[ws.Metric("Name")]
        )
    )
    workspace.save()
    print("Workspace with grouped runs saved.")

grouping_example(entity, project)
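Grouping merges runs that share a value into a single entry, and panels can then aggregate within each group. The full example later in this section uses groupby_aggfunc="mean" for this. The sketch below mimics that idea in plain Python for intuition only. The runs, including the repeated name "baseline", and their val_loss values are hypothetical, and the code does not reproduce how the W&B UI computes aggregates.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List

# Hypothetical runs: two reruns share the name "baseline".
runs: List[Dict] = [
    {"name": "baseline", "val_loss": 0.30},
    {"name": "baseline", "val_loss": 0.20},
    {"name": "wide", "val_loss": 0.10},
]


def group_mean(runs: List[Dict], key: str, metric: str) -> Dict[str, float]:
    """Group runs by `key` and average `metric` within each group."""
    groups: Dict[str, List[float]] = defaultdict(list)
    for run in runs:
        groups[run[key]].append(run[metric])
    return {group: mean(values) for group, values in groups.items()}


print(group_mean(runs, key="name", metric="val_loss"))
```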

Sort runs

This example demonstrates how to sort runs based on the validation loss summary:

def sorting_example(entity: str, project: str) -> None:
    workspace: ws.Workspace = ws.Workspace(
        name="Sorted Runs Workspace",
        entity=entity,
        project=project,
        sections=[
            ws.Section(
                name="Sorted Runs",
                panels=[
                    wr.LinePlot(x="Step", y=["val_loss"]),
                    wr.LinePlot(x="Step", y=["val_accuracy"]),
                ],
                is_open=True,
            ),
        ],
        runset_settings=ws.RunsetSettings(
            order=[ws.Ordering(ws.Summary("val_loss"))]  # Order by the val_loss summary
        )
    )
    workspace.save()
    print("Workspace with sorted runs saved.")

sorting_example(entity, project)

4. Putting it all together: comprehensive example

This example demonstrates how to create a comprehensive workspace, configure its settings, and add panels to sections:

def full_end_to_end_example(entity: str, project: str) -> None:
    # Get all runs in the project
    runs: list = wandb.Api().runs(f"{entity}/{project}")

    # Dynamically assign colors to the runs and create run settings
    run_colors: list = ['red', 'blue', 'green', 'orange', 'purple', 'teal', 'magenta', '#FAC13C']
    run_settings: dict = {}
    for i, run in enumerate(runs):
        run_settings[run.id] = ws.RunSettings(color=run_colors[i % len(run_colors)], disabled=False)

    workspace: ws.Workspace = ws.Workspace(
        name="My Workspace Template",
        entity=entity,
        project=project,
        sections=[
            ws.Section(
                name="Main Metrics",
                panels=[
                    wr.LinePlot(x="Step", y=["val_loss"]),
                    wr.LinePlot(x="Step", y=["val_accuracy"]),
                    wr.ScalarChart(metric="f1_score", groupby_aggfunc="mean"),
                ],
                is_open=True,
            ),
            ws.Section(
                name="Additional Metrics",
                panels=[
                    wr.ScalarChart(metric="precision", groupby_aggfunc="mean"),
                    wr.ScalarChart(metric="recall", groupby_aggfunc="mean"),
                ],
            ),
        ],
        settings=ws.WorkspaceSettings(
            x_axis="Step",
            x_min=0,
            x_max=100,
            smoothing_type="none",
            smoothing_weight=0,
            ignore_outliers=False,
            remove_legends_from_panels=False,
            tooltip_number_of_runs="default",
            tooltip_color_run_names=True,
            max_runs=20,
            point_visualization_method="bucketing",
            auto_expand_panel_search_results=False,
        ),
        runset_settings=ws.RunsetSettings(
            query="",
            regex_query=False,
            filters=[
                ws.Summary("val_loss") < 1,
                ws.Metric("Name") == "sample_run",
            ],
            groupby=[ws.Metric("Name")],
            order=[ws.Ordering(ws.Summary("Step"), ascending=True)],
            run_settings=run_settings
        )
    )
    workspace.save()
    print("Workspace created and saved.")

full_end_to_end_example(entity, project)

6 - Integration tutorials

6.1 - PyTorch

Use Weights & Biases for machine learning experiment tracking, dataset versioning, and project collaboration.

What this notebook covers

We show you how to integrate Weights & Biases with your PyTorch code to add experiment tracking to your pipeline.

# import the libraries
import torch
import wandb

# start a new experiment and capture a dictionary of hyperparameters with config
wandb.init(
    project="new-sota-model",
    config={"learning_rate": 0.001, "epochs": 100, "batch_size": 128},
)

# set up model and data (get_model and get_data stand in for your own code)
model, dataloader = get_model(), get_data()

# optional: track gradients
wandb.watch(model)

for batch in dataloader:
    metrics = model.training_step()  # stand-in: returns a dict of metrics, e.g. {"loss": ...}
    # log metrics inside your training loop to visualize model performance
    wandb.log(metrics)

# optional: save model at the end (sample_input is an example batch used to trace the model)
torch.onnx.export(model, sample_input, "model.onnx")
wandb.save("model.onnx")

Follow along with a video tutorial.

Note: Sections starting with Step are all you need to integrate W&B in an existing pipeline. The rest just loads data and defines a model.

Install, import, and log in

import os
import random

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm.auto import tqdm

# Ensure deterministic behavior with fixed seeds
# (Python randomizes str hash() per session, so hashed strings would not give repeatable seeds)
SEED = 42
torch.backends.cudnn.deterministic = True
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# remove slow mirror from list of MNIST mirrors
torchvision.datasets.MNIST.mirrors = [mirror for mirror in torchvision.datasets.MNIST.mirrors
                                      if not mirror.startswith("http://yann.lecun.com")]
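If you prefer to derive seeds from descriptive strings, keep in mind that Python's built-in hash() for str is randomized per interpreter session unless PYTHONHASHSEED is set, so it does not produce repeatable seeds. Operator precedence is a second pitfall: % binds more tightly than -, so an expression like h % 2**32 - 1 can evaluate to -1, which NumPy rejects as a seed. The sketch below shows one session-stable alternative that uses a SHA-256 digest. The helper name stable_seed is hypothetical and not part of the tutorial.

```python
import hashlib


def stable_seed(text: str) -> int:
    """Map a string to a seed in [0, 2**32 - 1] that is the same in every Python session."""
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    # Interpret the first 4 bytes as an unsigned 32-bit integer (NumPy needs 0 <= seed < 2**32).
    return int.from_bytes(digest[:4], "big")


seed = stable_seed("setting random seeds")
print(seed)
```

You could then call random.seed(seed), np.random.seed(seed), and torch.manual_seed(seed) as usual.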

0️⃣ Step 0: Install W&B

To get started, we’ll need to get the library. wandb is easily installed using pip.

!pip install wandb onnx -Uq

1️⃣ Step 1: Import W&B and Login

In order to log data to our web service, you’ll need to log in.

If this is your first time using W&B, you’ll need to sign up for a free account at the link that appears.

import wandb

wandb.login()

Define the Experiment and Pipeline

Track metadata and hyperparameters with wandb.init

Programmatically, the first thing we do is define our experiment: What are the hyperparameters? What metadata is associated with this run?

It’s a pretty common workflow to store this information in a config dictionary (or similar object) and then access it as needed.

For this example, we’re only letting a few hyperparameters vary and hand-coding the rest. But any part of your model can be part of the config.

We also include some metadata: we’re using the MNIST dataset and a convolutional architecture. If we later work with, say, fully connected architectures on CIFAR in the same project, this will help us separate our runs.

config = dict(
    epochs=5,
    classes=10,
    kernels=[16, 32],
    batch_size=128,
    learning_rate=0.005,
    dataset="MNIST",
    architecture="CNN")

Now, let’s define the overall pipeline, which is pretty typical for model-training:

  1. we first make a model, plus associated data and optimizer, then
  2. we train the model accordingly and finally
  3. test it to see how training went.

We’ll implement these functions below.

def model_pipeline(hyperparameters):

    # tell wandb to get started
    with wandb.init(project="pytorch-demo", config=hyperparameters):
        # access all HPs through wandb.config, so logging matches execution.
        config = wandb.config

        # make the model, data, and optimization problem
        model, train_loader, test_loader, criterion, optimizer = make(config)
        print(model)

        # and use them to train the model
        train(model, train_loader, criterion, optimizer, config)

        # and test its final performance
        test(model, test_loader)

    return model

The only difference here from a standard pipeline is that it all occurs inside the context of wandb.init. Calling this function sets up a line of communication between your code and our servers.

Passing the config dictionary to wandb.init immediately logs all that information to us, so you’ll always know what hyperparameter values you set your experiment to use.

To ensure the values you chose and logged are always the ones that get used in your model, we recommend using the wandb.config copy of your object. Check the definition of make below to see some examples.

Side Note: We take care to run our code in separate processes, so that any issues on our end (such as if a giant sea monster attacks our data centers) don’t crash your code. Once the issue is resolved, such as when the Kraken returns to the deep, you can log the data with wandb sync.

def make(config):
    # Make the data (named so they don't shadow the train and test functions)
    train_data, test_data = get_data(train=True), get_data(train=False)
    train_loader = make_loader(train_data, batch_size=config.batch_size)
    test_loader = make_loader(test_data, batch_size=config.batch_size)

    # Make the model
    model = ConvNet(config.kernels, config.classes).to(device)

    # Make the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.learning_rate)
    
    return model, train_loader, test_loader, criterion, optimizer

Define the Data Loading and Model

Now, we need to specify how the data is loaded and what the model looks like.

This part is very important, but it’s no different from what it would be without wandb, so we won’t dwell on it.

def get_data(slice=5, train=True):
    full_dataset = torchvision.datasets.MNIST(root=".",
                                              train=train, 
                                              transform=transforms.ToTensor(),
                                              download=True)
    #  equiv to slicing with [::slice] 
    sub_dataset = torch.utils.data.Subset(
      full_dataset, indices=range(0, len(full_dataset), slice))
    
    return sub_dataset


def make_loader(dataset, batch_size):
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size, 
                                         shuffle=True,
                                         pin_memory=True, num_workers=2)
    return loader
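Using Subset with indices=range(0, len(full_dataset), slice) keeps every slice-th example, which matches slicing with [::slice]. The arithmetic below shows how much data that leaves. It assumes the standard MNIST split sizes of 60,000 training and 10,000 test images, the default slice=5, and batch_size=128 from the config.

```python
import math


def subset_size(n: int, step: int) -> int:
    """Number of indices in range(0, n, step), i.e. ceil(n / step)."""
    return math.ceil(n / step)


# range(0, n, step) picks the same indices as slicing with [::step]
assert list(range(0, 12, 5)) == list(range(12))[::5]

train_examples = subset_size(60_000, 5)
test_examples = subset_size(10_000, 5)
batches_per_epoch = math.ceil(train_examples / 128)  # last batch is partial (drop_last=False)
print(train_examples, test_examples, batches_per_epoch)  # 12000 2000 94
```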

Defining the model is normally the fun part.

But nothing changes with wandb, so we’re gonna stick with a standard ConvNet architecture.

Don’t be afraid to mess around with this and try some experiments – all your results will be logged on wandb.ai.

# Conventional and convolutional neural network

class ConvNet(nn.Module):
    def __init__(self, kernels, classes=10):
        super(ConvNet, self).__init__()
        
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, kernels[0], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(kernels[0], kernels[1], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7 * 7 * kernels[-1], classes)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
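The in_features of self.fc is 7 * 7 * kernels[-1] for the following reason. Each 5x5 convolution with padding=2 and stride=1 preserves the 28x28 MNIST spatial size, and each 2x2 max-pool with stride 2 halves it, giving 28 → 14 → 7. Here is a small sketch of that arithmetic, using the standard output-size formula floor((n + 2p - k) / s) + 1 (no dilation) and the default kernels=[16, 32]:

```python
from typing import Sequence


def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (n + 2 * padding - kernel) // stride + 1


def convnet_flat_features(side: int = 28, kernels: Sequence[int] = (16, 32)) -> int:
    for _ in range(2):  # layer1 and layer2 share the same structure
        side = out_size(side, kernel=5, stride=1, padding=2)  # Conv2d keeps the size
        side = out_size(side, kernel=2, stride=2)             # MaxPool2d halves it
    return side * side * kernels[-1]


print(convnet_flat_features())  # 7 * 7 * 32 = 1568
```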

Define Training Logic

Moving on in our model_pipeline, it’s time to specify how we train.

Two wandb functions come into play here: watch and log.

Track gradients with wandb.watch and everything else with wandb.log

wandb.watch will log the gradients and the parameters of your model, every log_freq steps of training.

All you need to do is call it before you start training.

The rest of the training code remains the same: we iterate over epochs and batches, running forward and backward passes and applying our optimizer.

def train(model, loader, criterion, optimizer, config):
    # Tell wandb to watch what the model gets up to: gradients, weights, and more.
    wandb.watch(model, criterion, log="all", log_freq=10)

    # Run training and track with wandb
    example_ct = 0  # number of examples seen
    batch_ct = 0
    for epoch in tqdm(range(config.epochs)):
        for images, labels in loader:

            loss = train_batch(images, labels, model, optimizer, criterion)
            example_ct += len(images)
            batch_ct += 1

            # Report metrics every 25th batch
            if (batch_ct % 25) == 0:
                train_log(loss, example_ct, epoch)


def train_batch(images, labels, model, optimizer, criterion):
    images, labels = images.to(device), labels.to(device)
    
    # Forward pass ➡
    outputs = model(images)
    loss = criterion(outputs, labels)
    
    # Backward pass ⬅
    optimizer.zero_grad()
    loss.backward()

    # Step with optimizer
    optimizer.step()

    return loss
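To see how often the train loop reports, the sketch below counts the logging points when metrics go out on every 25th batch, as the comment in train describes. It also tracks the example_ct value passed as step. It assumes 12,000 training examples (MNIST subsampled with slice=5), batch_size=128, and 5 epochs, which works out to 94 batches per epoch. The helper is hypothetical and only reproduces the counting, not the training.

```python
import math
from typing import List, Tuple


def logged_batches(num_examples: int, batch_size: int, epochs: int,
                   every: int = 25) -> List[Tuple[int, int]]:
    """Return (batch_ct, example_ct) pairs at which train_log would fire."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    points: List[Tuple[int, int]] = []
    batch_ct, example_ct = 0, 0
    for _ in range(epochs):
        for b in range(batches_per_epoch):
            # The last batch of an epoch may be smaller (DataLoader default drop_last=False).
            example_ct += min(batch_size, num_examples - b * batch_size)
            batch_ct += 1
            if batch_ct % every == 0:
                points.append((batch_ct, example_ct))
    return points


points = logged_batches(12_000, 128, epochs=5)
print(len(points), points[0], points[-1])  # 18 (25, 3200) (450, 57472)
```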

The only difference is in the logging code: where previously you might have reported metrics by printing to the terminal, now you pass the same information to wandb.log.

wandb.log expects a dictionary with strings as keys. These strings identify the objects being logged, which make up the values. You can also optionally log which step of training you’re on.

Side Note: I like to use the number of examples the model has seen, since this makes for easier comparison across batch sizes, but you can use raw steps or batch count. For longer training runs, it can also make sense to log by epoch.

def train_log(loss, example_ct, epoch):
    # Where the magic happens
    wandb.log({"epoch": epoch, "loss": loss}, step=example_ct)
    print(f"Loss after {str(example_ct).zfill(5)} examples: {loss:.3f}")

Define Testing Logic

Once the model is done training, we want to test it: run it against some fresh data from production, perhaps, or apply it to some hand-curated examples.

(Optional) Call wandb.save

This is also a great time to save the model’s architecture and final parameters to disk. For maximum compatibility, we’ll export our model in the Open Neural Network eXchange (ONNX) format.

Passing that filename to wandb.save ensures that the model parameters are saved to W&B’s servers: no more losing track of which .h5 or .pb file corresponds to which training run.

For more advanced wandb features for storing, versioning, and distributing models, check out our Artifacts tools.

def test(model, test_loader):
    model.eval()

    # Run the model on some test examples
    with torch.no_grad():
        correct, total = 0, 0
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Accuracy of the model on the {total} " +
              f"test images: {correct / total:%}")
        
        wandb.log({"test_accuracy": correct / total})

    # Save the model in the exchangeable ONNX format
    torch.onnx.export(model, images, "model.onnx")
    wandb.save("model.onnx")

Run training and watch your metrics live on wandb.ai

Now that we’ve defined the whole pipeline and slipped in those few lines of W&B code, we’re ready to run our fully tracked experiment.

We’ll report a few links to you: our documentation, the Project page, which organizes all the runs in a project, and the Run page, where this run’s results will be stored.

Navigate to the Run page and check out these tabs:

  1. Charts, where the model gradients, parameter values, and loss are logged throughout training
  2. System, which contains a variety of system metrics, including Disk I/O utilization, CPU and GPU metrics (watch that temperature soar 🔥), and more
  3. Logs, which has a copy of anything pushed to standard out during training
  4. Files, where, once training is complete, you can click on the model.onnx to view our network with the Netron model viewer.

When the run finishes (that is, when the with wandb.init block exits), we’ll also print a summary of the results in the cell output.

# Build, train and analyze the model with the pipeline
model = model_pipeline(config)

Test Hyperparameters with Sweeps

We only looked at a single set of hyperparameters in this example. But an important part of most ML workflows is iterating over a number of hyperparameters.

You can use Weights & Biases Sweeps to automate hyperparameter testing and explore the space of possible models and optimization strategies.

Check out Hyperparameter Optimization in PyTorch using W&B Sweeps

Running a hyperparameter sweep with Weights & Biases is very easy. There are just 3 simple steps:

  1. Define the sweep: We do this by creating a dictionary or a YAML file that specifies the parameters to search through, the search strategy, the optimization metric, and so on.

  2. Initialize the sweep: sweep_id = wandb.sweep(sweep_config)

  3. Run the sweep agent: wandb.agent(sweep_id, function=train)

That’s all there is to running a hyperparameter sweep.
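For the pipeline in this tutorial, step 1 might look like the dictionary below. The search method, metric, and value ranges are illustrative assumptions, not tuned recommendations. The metric name must match a key you actually pass to wandb.log, here "loss" from train_log. The wandb calls are inside a function, so the configuration can be inspected without starting a sweep.

```python
from typing import Callable, Dict, Optional


def build_sweep_config() -> Dict:
    """Illustrative sweep definition (step 1)."""
    return {
        "method": "random",  # other options include "grid" and "bayes"
        "metric": {"name": "loss", "goal": "minimize"},
        "parameters": {
            "learning_rate": {"values": [0.001, 0.005, 0.01]},
            "batch_size": {"values": [64, 128]},
            "epochs": {"value": 5},  # held constant across the sweep
        },
    }


def launch_sweep(train_fn: Callable[[], None], project: str = "pytorch-demo",
                 count: Optional[int] = 5) -> str:
    """Steps 2 and 3: register the sweep, then run an agent that calls train_fn."""
    import wandb  # imported here so build_sweep_config() works without wandb installed

    sweep_id = wandb.sweep(build_sweep_config(), project=project)
    wandb.agent(sweep_id, function=train_fn, count=count)
    return sweep_id
```

Inside the function you pass as train_fn, call wandb.init() and read the sampled values from wandb.config, as model_pipeline already does.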

See examples of projects tracked and visualized with W&B in our Gallery →

Advanced Setup

  1. Environment variables: Set API keys in environment variables so you can run training on a managed cluster.
  2. Offline mode: Use offline mode (for example, WANDB_MODE=offline) to train without a network connection and sync results later with wandb sync.
  3. On-prem: Install W&B in a private cloud or air-gapped servers in your own infrastructure. We have local installations for everyone from academics to enterprise teams.
  4. Sweeps: Set up hyperparameter search quickly with our lightweight tool for tuning.
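As an illustration of the first two items, the sketch below configures a run through environment variables before wandb is initialized. WANDB_API_KEY, WANDB_PROJECT, and WANDB_MODE are standard W&B environment variables. The helper name, the project name, and the choice of offline mode are assumptions for the example. Offline runs are written to the local wandb directory and can be uploaded later by running wandb sync on the run directory.

```python
import os
from typing import Dict, Union


def configure_wandb_env(project: str, offline: bool = False) -> Dict[str, Union[str, bool]]:
    """Set W&B environment variables; call this before wandb.init()."""
    # The API key should come from the cluster's secret store (an exported
    # WANDB_API_KEY); it is only checked here, never hard-coded.
    os.environ["WANDB_PROJECT"] = project
    # "offline" logs to the local ./wandb directory for a later `wandb sync`.
    os.environ["WANDB_MODE"] = "offline" if offline else "online"
    return {
        "project": os.environ["WANDB_PROJECT"],
        "mode": os.environ["WANDB_MODE"],
        "has_api_key": "WANDB_API_KEY" in os.environ,
    }


print(configure_wandb_env("pytorch-demo", offline=True))
```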

6.2 - PyTorch Lightning

We will build an image classification pipeline using PyTorch Lightning. We will follow this style guide to increase the readability and reproducibility of our code. A cool explanation of this is available here.

Setting up PyTorch Lightning and W&B

For this tutorial, we need PyTorch Lightning and Weights & Biases.

pip install lightning -q
pip install wandb -qU
import lightning.pytorch as pl

# your favorite machine learning tracking tool
from lightning.pytorch.loggers import WandbLogger

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader

from torchmetrics import Accuracy

from torchvision import transforms
from torchvision.datasets import CIFAR10

import wandb

Now you’ll need to log in to your wandb account.

wandb.login()

DataModule - The Data Pipeline we Deserve

DataModules are a way of decoupling data-related hooks from the LightningModule so you can develop dataset agnostic models.

It organizes the data pipeline into one shareable and reusable class. A datamodule encapsulates the five steps involved in data processing in PyTorch:

  • Download / tokenize / process.
  • Clean and (maybe) save to disk.
  • Load inside Dataset.
  • Apply transforms (rotate, tokenize, etc…).
  • Wrap inside a DataLoader.

Learn more about datamodules here. Let’s build a datamodule for the CIFAR-10 dataset.

class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        self.num_classes = 10
    
    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)
    
    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

Callbacks

A callback is a self-contained program that can be reused across projects. PyTorch Lightning comes with a few built-in callbacks which are regularly used. Learn more about callbacks in PyTorch Lightning here.

Built-in Callbacks

In this tutorial, we will use Early Stopping and Model Checkpoint built-in callbacks. They can be passed to the Trainer.
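Conceptually, EarlyStopping(monitor="val_loss") tracks the best value of the monitored metric. It stops training once that metric has failed to improve for a number of consecutive validation checks (patience, which defaults to 3 in Lightning). The plain-Python class below is a simplified sketch of that logic for intuition only. It is not Lightning's implementation, and the class and helper names are hypothetical.

```python
from typing import List, Optional


class SimpleEarlyStopping:
    """Stop when the monitored value has not improved for `patience` checks (mode="min")."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best: Optional[float] = None
        self.wait = 0

    def should_stop(self, value: float) -> bool:
        if self.best is None or value < self.best - self.min_delta:
            self.best = value  # improvement: remember it and reset the counter
            self.wait = 0
        else:
            self.wait += 1     # no improvement on this check
        return self.wait >= self.patience


def stopping_epoch(val_losses: List[float], patience: int = 3) -> Optional[int]:
    """Return the epoch index at which training would stop, or None."""
    stopper = SimpleEarlyStopping(patience=patience)
    for epoch, loss in enumerate(val_losses):
        if stopper.should_stop(loss):
            return epoch
    return None


print(stopping_epoch([0.9, 0.7, 0.72, 0.71, 0.75, 0.6]))  # 4
```

In this trace, the best loss of 0.7 arrives at epoch 1, and epochs 2, 3, and 4 fail to beat it, so training stops at epoch 4 before the later 0.6 is ever seen.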

Custom Callbacks

If you are familiar with custom Keras callbacks, the ability to do the same in your PyTorch pipeline is just the cherry on the cake.

Since we are performing image classification, it helps to visualize the model’s predictions on a sample of images. Implementing this as a callback can help debug the model at an early stage.

class ImagePredictionLogger(pl.callbacks.Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples
    
    def on_validation_epoch_end(self, trainer, pl_module):
        # Move the tensors to the same device as the model
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        # Get model prediction
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)
        # Log the images as wandb Image
        trainer.logger.experiment.log({
            "examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") 
                           for x, pred, y in zip(val_imgs[:self.num_samples], 
                                                 preds[:self.num_samples], 
                                                 val_labels[:self.num_samples])]
            })
        

LightningModule - Define the System

The LightningModule defines a system and not a model. Here a system groups all the research code into a single class to make it self-contained. LightningModule organizes your PyTorch code into 5 sections:

  • Computations (__init__).
  • Train loop (training_step)
  • Validation loop (validation_step)
  • Test loop (test_step)
  • Optimizers (configure_optimizers)

One can thus build a dataset-agnostic model that can be easily shared. Let’s build a system for CIFAR-10 classification.

class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()
        
        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = torch.nn.MaxPool2d(2)
        self.pool2 = torch.nn.MaxPool2d(2)
        
        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    # returns the size of the output tensor going into Linear layer from the conv block.
    def _get_conv_output(self, shape):
        batch_size = 1
        # A random dummy batch is enough to measure the output shape
        dummy_input = torch.rand(batch_size, *shape)

        output_feat = self._forward_features(dummy_input)
        n_size = output_feat.view(batch_size, -1).size(1)
        return n_size
        
    # returns the feature tensor from the conv block
    def _forward_features(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x
    
    # will be used during inference
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x), dim=1)

        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # test metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
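_get_conv_output finds the input size of fc1 by pushing a dummy tensor through the conv block. The same number can be derived by hand. The four 3x3 convolutions use stride 1 and no padding, so each one trims 2 pixels, and each MaxPool2d(2) halves the size. Here is a quick check for 3x32x32 CIFAR-10 inputs, with a hypothetical helper that follows the layer order in _forward_features:

```python
def lit_model_flat_features(side: int = 32, out_channels: int = 64) -> int:
    """Trace the spatial size through conv1..conv4 and the two pools of LitModel."""
    ops = ["conv", "conv", "pool", "conv", "conv", "pool"]  # order in _forward_features
    for op in ops:
        if op == "conv":
            side = side - 3 + 1  # 3x3 kernel, stride 1, no padding
        else:
            side = side // 2     # MaxPool2d(2): kernel 2, stride 2
    return side * side * out_channels


print(lit_model_flat_features())  # 32 -> 30 -> 28 -> 14 -> 12 -> 10 -> 5, so 5 * 5 * 64 = 1600
```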

Train and Evaluate

Now that we have organized our data pipeline using DataModule and model architecture+training loop using LightningModule, the PyTorch Lightning Trainer automates everything else for us.

The Trainer automates:

  • Epoch and batch iteration
  • Calling of optimizer.step(), backward, zero_grad()
  • Calling of .eval(), enabling/disabling grads
  • Saving and loading weights
  • Weights and Biases logging
  • Multi-GPU training support
  • TPU support
  • 16-bit training support

dm = CIFAR10DataModule(batch_size=32)
# To access the x_dataloader we need to call prepare_data and setup.
dm.prepare_data()
dm.setup()

# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
model = LitModel((3, 32, 32), dm.num_classes)

# Initialize wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# Initialize Callbacks
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()

# Initialize a trainer
trainer = pl.Trainer(max_epochs=2,
                     logger=wandb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples),
                                checkpoint_callback],
                     )

# Train the model 
trainer.fit(model, dm)

# Evaluate the model on the held-out test set ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())

# Close wandb run
wandb.finish()

Final Thoughts

I come from the TensorFlow/Keras ecosystem and find PyTorch a bit overwhelming even though it’s an elegant framework. Just my personal experience though. While exploring PyTorch Lightning, I realized that almost all of the reasons that kept me away from PyTorch are taken care of. Here’s a quick summary of my excitement:

  • Then: Conventional PyTorch model definition used to be all over the place, with the model in some model.py script and the training loop in the train.py file. It was a lot of looking back and forth to understand the pipeline.
  • Now: The LightningModule acts as a system where the model is defined along with the training_step, validation_step, etc. Now it’s modular and shareable.
  • Then: The best part about TensorFlow/Keras is the input data pipeline. Their dataset catalog is rich and growing. PyTorch’s data pipeline used to be the biggest pain point. In normal PyTorch code, the data download/cleaning/preparation is usually scattered across many files.
  • Now: The DataModule organizes the data pipeline into one shareable and reusable class. It’s simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the matching transforms and data processing/downloads steps required.
  • Then: With Keras, one can call model.fit to train the model and model.predict to run inference. model.evaluate offered a good old simple evaluation on the test data. This is not the case with PyTorch. One will usually find separate train.py and test.py files.
  • Now: With the LightningModule in place, the Trainer automates everything. One needs to just call trainer.fit and trainer.test to train and evaluate the model.
  • Then: TensorFlow loves TPU, PyTorch…
  • Now: With PyTorch Lightning, it’s so easy to train the same model with multiple GPUs and even on TPU.
  • Then: I am a big fan of Callbacks and prefer writing custom callbacks. Something as trivial as Early Stopping used to be a point of discussion with conventional PyTorch.
  • Now: With PyTorch Lightning, using Early Stopping and Model Checkpointing is a piece of cake. I can even write custom callbacks.

🎨 Conclusion and Resources

I hope you find this report helpful. I encourage you to play with the code and train an image classifier with a dataset of your choice.

Here are some resources to learn more about PyTorch Lightning:

6.3 - Hugging Face

Visualize your Hugging Face model’s performance quickly with a seamless W&B integration.

Compare hyperparameters, output metrics, and system stats like GPU utilization across your models.

Why should I use W&B?

  • Unified dashboard: Central repository for all your model metrics and predictions
  • Lightweight: No code changes required to integrate with Hugging Face
  • Accessible: Free for individuals and academic teams
  • Secure: All projects are private by default
  • Trusted: Used by machine learning teams at OpenAI, Toyota, Lyft and more

Think of W&B like GitHub for machine learning models: save machine learning experiments to your private, hosted dashboard. Experiment quickly with the confidence that all the versions of your models are saved for you, no matter where you’re running your scripts.

W&B’s lightweight integrations work with any Python script, and all you need to do is sign up for a free W&B account to start tracking and visualizing your models.

In the Hugging Face Transformers repo, we’ve instrumented the Trainer to automatically log training and evaluation metrics to W&B at each logging step.

Here’s an in depth look at how the integration works: Hugging Face + W&B Report.

Install, import, and log in

Install the Hugging Face and Weights & Biases libraries, and the GLUE dataset and training script for this tutorial.

!pip install datasets wandb evaluate accelerate -qU
!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/pytorch/text-classification/run_glue.py
# the run_glue.py script requires transformers dev
!pip install -q git+https://github.com/huggingface/transformers

Before continuing, sign up for a free account.

Put in your API key

Once you’ve signed up, run the next cell and click on the link to get your API key and authenticate this notebook.

import wandb
wandb.login()

Optionally, we can set environment variables to customize W&B logging. See documentation.

# Optional: log both gradients and parameters
%env WANDB_WATCH=all
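The %env magics only work in notebooks. If you launch run_glue.py or your own Trainer code from a plain Python script, a minimal sketch is to set the same variables with os.environ before training starts. This assumes the integration reads these variables when training begins.

```python
import os

# Plain-Python equivalent of the %env magics used in this notebook.
# Set these before training starts so the W&B integration can pick them up.
os.environ["WANDB_PROJECT"] = "huggingface-demo"  # project that receives the runs
os.environ["WANDB_WATCH"] = "all"  # log both gradients and parameters
```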

Train the model

Next, call the downloaded training script run_glue.py. Training is tracked automatically in the Weights & Biases dashboard. This script fine-tunes BERT on the Microsoft Research Paraphrase Corpus (MRPC), which consists of sentence pairs with human annotations indicating whether the two sentences are semantically equivalent.

%env WANDB_PROJECT=huggingface-demo
%env TASK_NAME=MRPC

!python run_glue.py \
  --model_name_or_path bert-base-uncased \
  --task_name $TASK_NAME \
  --do_train \
  --do_eval \
  --max_seq_length 256 \
  --per_device_train_batch_size 32 \
  --learning_rate 2e-4 \
  --num_train_epochs 3 \
  --output_dir /tmp/$TASK_NAME/ \
  --overwrite_output_dir \
  --logging_steps 50
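To estimate how many training-loss points this command produces on the dashboard, you can work through the step arithmetic. The sketch below assumes the MRPC training split has 3,668 sentence pairs, training runs on a single device, and there is no gradient accumulation.

```python
import math

# Assumptions: 3,668 MRPC training pairs, one device, no gradient accumulation.
num_train_examples = 3668
per_device_train_batch_size = 32  # --per_device_train_batch_size
num_train_epochs = 3  # --num_train_epochs
logging_steps = 50  # --logging_steps

# The last partial batch still counts as a step.
steps_per_epoch = math.ceil(num_train_examples / per_device_train_batch_size)
total_steps = steps_per_epoch * num_train_epochs
# The Trainer logs training metrics every `logging_steps` optimizer steps.
num_logged_points = total_steps // logging_steps

print(steps_per_epoch, total_steps, num_logged_points)  # 115 345 6
```

To get a denser loss curve in W&B, lower --logging_steps.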

Visualize results in dashboard

Click the link printed out above, or go to wandb.ai to see your results stream in live. The link to see your run in the browser will appear after all the dependencies are loaded. Look for the following output: “wandb: 🚀 View run at [URL to your unique run]”

Visualize Model Performance: It’s easy to look across dozens of experiments, zoom in on interesting findings, and visualize high-dimensional data.

Compare Architectures: Here’s an example comparing BERT and DistilBERT. Automatic line plot visualizations make it easy to see how different architectures affect evaluation accuracy throughout training.

Track key information effortlessly by default

Weights & Biases saves a new run for each experiment. Here’s the information that gets saved by default:

  • Hyperparameters: Settings for your model are saved in Config
  • Model Metrics: Time series data of metrics streaming in are saved in Log
  • Terminal Logs: Command line outputs are saved and available in a tab
  • System Metrics: GPU and CPU utilization, memory, temperature etc.

Learn more

  • Documentation: docs on the Weights & Biases and Hugging Face integration
  • Videos: tutorials, interviews with practitioners, and more on our YouTube channel
  • Contact: Message us at contact@wandb.com with questions

6.4 - TensorFlow

Use Weights & Biases for machine learning experiment tracking, dataset versioning, and project collaboration.

What this notebook covers

  • Easy integration of Weights and Biases with your TensorFlow pipeline for experiment tracking.
  • Computing metrics with keras.metrics
  • Using wandb.log to log those metrics in your custom training loop.

Note: Sections starting with Step are all you need to integrate W&B into existing code. The rest is just a standard MNIST example.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import cifar10

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

Install, Import, Login

Install W&B

%%capture
!pip install wandb

Import W&B and login

import wandb
from wandb.integration.keras import WandbMetricsLogger

wandb.login()

Side note: If this is your first time using W&B or you are not logged in, the link that appears after you run wandb.login() takes you to the sign-up/login page. Signing up takes just one click.

Prepare Dataset

# Prepare the training dataset
BATCH_SIZE = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# build input pipeline using tf.data
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(BATCH_SIZE)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(BATCH_SIZE)

Define the Model and the Training Loop

def make_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)

    return keras.Model(inputs=inputs, outputs=outputs)


def train_step(x, y, model, optimizer, loss_fn, train_acc_metric):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    train_acc_metric.update_state(y, logits)

    return loss_value


def test_step(x, y, model, loss_fn, val_acc_metric):
    val_logits = model(x, training=False)
    loss_value = loss_fn(y, val_logits)
    val_acc_metric.update_state(y, val_logits)

    return loss_value

Add wandb.log to your training loop

def train(train_dataset, val_dataset,  model, optimizer,
          train_acc_metric, val_acc_metric,
          epochs=10,  log_step=200, val_log_step=50):
  
    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch,))

        train_loss = []   
        val_loss = []

        # Iterate over the batches of the dataset
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            loss_value = train_step(x_batch_train, y_batch_train, 
                                    model, optimizer, 
                                    loss_fn, train_acc_metric)
            train_loss.append(float(loss_value))

        # Run a validation loop at the end of each epoch
        for step, (x_batch_val, y_batch_val) in enumerate(val_dataset):
            val_loss_value = test_step(x_batch_val, y_batch_val, 
                                       model, loss_fn, 
                                       val_acc_metric)
            val_loss.append(float(val_loss_value))
            
        # Display metrics at the end of each epoch
        train_acc = train_acc_metric.result()
        print("Training acc over epoch: %.4f" % (float(train_acc),))

        val_acc = val_acc_metric.result()
        print("Validation acc: %.4f" % (float(val_acc),))

        # Reset metrics at the end of each epoch
        # reset_state() is the current name; older TF versions also accepted reset_states()
        train_acc_metric.reset_state()
        val_acc_metric.reset_state()

        # ⭐: log metrics using wandb.log
        wandb.log({'epochs': epoch,
                   'loss': np.mean(train_loss),
                   'acc': float(train_acc), 
                   'val_loss': np.mean(val_loss),
                   'val_acc':float(val_acc)})
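train_acc_metric and val_acc_metric accumulate counts across every batch in an epoch. That is why the loop reads result() once per epoch and then resets the metrics. The NumPy class below is a hypothetical stand-in that sketches this accumulate, read, and reset pattern. It is not the Keras implementation.

```python
import numpy as np


class StreamingAccuracy:
    """Hypothetical stand-in for keras.metrics.SparseCategoricalAccuracy.

    Accumulates correct and total counts across batches until reset.
    """

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update_state(self, y_true, logits):
        # The predicted class is the index of the largest logit in each row.
        preds = np.argmax(logits, axis=-1)
        self.correct += int(np.sum(preds == np.asarray(y_true)))
        self.total += len(y_true)

    def result(self):
        return self.correct / self.total if self.total else 0.0

    def reset_state(self):
        self.correct = 0
        self.total = 0


metric = StreamingAccuracy()
metric.update_state([0, 1], np.array([[2.0, 0.1], [0.3, 0.4]]))  # 2 of 2 correct
metric.update_state([1, 1], np.array([[1.0, 0.0], [0.0, 1.0]]))  # 1 of 2 correct
epoch_acc = metric.result()  # 3 / 4 = 0.75 over the whole "epoch"
metric.reset_state()  # the next epoch starts from zero
```

Without the reset, each epoch's accuracy would silently include every earlier epoch.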

Run Training

Call wandb.init to start a run

This lets us know you’re launching an experiment, so we can give it a unique ID and a dashboard.

Check out the official documentation

# Initialize wandb with your project name and, optionally, with configurations.
# Play around with the config values and see the result on your wandb dashboard.
config = {
    "learning_rate": 0.001,
    "epochs": 10,
    "batch_size": 64,
    "log_step": 200,
    "val_log_step": 50,
    "architecture": "MLP",
    "dataset": "MNIST",
}

run = wandb.init(project='my-tf-integration', config=config)
config = wandb.config

# Initialize model.
model = make_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=config.learning_rate)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

train(train_dataset,
      val_dataset, 
      model,
      optimizer,
      train_acc_metric,
      val_acc_metric,
      epochs=config.epochs, 
      log_step=config.log_step, 
      val_log_step=config.val_log_step)

run.finish()  # In Jupyter/Colab, let us know you're finished!
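The final Dense layer in make_model has no activation, so the model outputs raw logits. That is why loss_fn is created with from_logits=True. The NumPy sketch below shows what a sparse categorical cross-entropy on logits computes. It uses the standard max-subtraction trick for numerical stability and is not the Keras source.

```python
import numpy as np


def sparse_ce_from_logits(y_true, logits):
    """Mean cross-entropy for integer labels and raw (unnormalized) logits."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract each row's max before exponentiating to avoid overflow.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the true class in each row.
    picked = log_probs[np.arange(len(y_true)), np.asarray(y_true)]
    return float(-picked.mean())


# Equal logits over 10 classes give a uniform softmax, so the loss is log(10).
loss = sparse_ce_from_logits([3], np.zeros((1, 10)))
```

This gives a quick sanity check for the dashboard: an untrained 10-class model should start with a loss near log(10) ≈ 2.3.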

Visualize Results

Click on the run page link above to see your live results.

Sweep 101

Use Weights & Biases Sweeps to automate hyperparameter optimization and explore the space of possible models.

Check out Hyperparameter Optimization in TensorFlow using W&B Sweeps

Benefits of using W&B Sweeps

  • Quick setup: With just a few lines of code you can run W&B sweeps.
  • Transparent: We cite all the algorithms we’re using, and our code is open source.
  • Powerful: Our sweeps are completely customizable and configurable. You can launch a sweep across dozens of machines, and it’s just as easy as starting a sweep on your laptop.

See examples of projects tracked and visualized with W&B in our gallery of examples, Fully Connected →

📏 Best Practices

  1. Projects: Log multiple runs to a project to compare them. wandb.init(project="project-name")
  2. Groups: For multiple processes or cross-validation folds, log each process as a run and group them together. wandb.init(group='experiment-1')
  3. Tags: Add tags to track your current baseline or production model.
  4. Notes: Type notes in the table to track the changes between runs.
  5. Reports: Take quick notes on progress to share with colleagues and make dashboards and snapshots of your ML projects.

Advanced Setup

  1. Environment variables: Set API keys in environment variables so you can run training on a managed cluster.
  2. Offline mode
  3. On-prem: Install W&B in a private cloud or on air-gapped servers in your own infrastructure. We have local installations for everyone from academics to enterprise teams.
  4. Artifacts: Track and version models and datasets in a streamlined way that automatically picks up your pipeline steps as you train models.

6.5 - TensorFlow Sweeps

Use Weights & Biases for machine learning experiment tracking, dataset versioning, and project collaboration.

Use Weights & Biases Sweeps to automate hyperparameter optimization and explore the space of possible models, complete with interactive dashboards.

Why Should I Use Sweeps?

  • Quick setup: With just a few lines of code, you can run W&B sweeps.
  • Transparent: The project cites all algorithms used, and the code is open source.
  • Powerful: Sweeps are completely customizable and configurable. You can launch a sweep across dozens of machines, and it’s just as easy as starting a sweep on your laptop.

Check out the official documentation

What this notebook covers

  • Simple steps to get started with W&B Sweeps using a custom training loop in TensorFlow.
  • Finding the best hyperparameters for an image classification task.

Note: Sections starting with Step are all you need to perform a hyperparameter sweep in existing code. The rest of the code sets up a simple example.

Install, Import, and Log in

Install W&B

%%capture
!pip install wandb

Import W&B and Login

import tqdm
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import cifar10

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import wandb
from wandb.integration.keras import WandbMetricsLogger

wandb.login()

Side note: If this is your first time using W&B or you are not logged in, the link that appears after you run wandb.login() takes you to the sign-up/login page. Signing up takes just a few clicks.

Prepare Dataset

# Prepare the training dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train / 255.0
x_test = x_test / 255.0
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

Build a Simple Classifier MLP

def Model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)

    return keras.Model(inputs=inputs, outputs=outputs)


def train_step(x, y, model, optimizer, loss_fn, train_acc_metric):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    train_acc_metric.update_state(y, logits)

    return loss_value


def test_step(x, y, model, loss_fn, val_acc_metric):
    val_logits = model(x, training=False)
    loss_value = loss_fn(y, val_logits)
    val_acc_metric.update_state(y, val_logits)

    return loss_value

Write a Training Loop

def train(
    train_dataset,
    val_dataset,
    model,
    optimizer,
    loss_fn,
    train_acc_metric,
    val_acc_metric,
    epochs=10,
    log_step=200,
    val_log_step=50,
):

    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch,))

        train_loss = []
        val_loss = []

        # Iterate over the batches of the dataset
        for step, (x_batch_train, y_batch_train) in tqdm.tqdm(
            enumerate(train_dataset), total=len(train_dataset)
        ):
            loss_value = train_step(
                x_batch_train,
                y_batch_train,
                model,
                optimizer,
                loss_fn,
                train_acc_metric,
            )
            train_loss.append(float(loss_value))

        # Run a validation loop at the end of each epoch
        for step, (x_batch_val, y_batch_val) in enumerate(val_dataset):
            val_loss_value = test_step(
                x_batch_val, y_batch_val, model, loss_fn, val_acc_metric
            )
            val_loss.append(float(val_loss_value))

        # Display metrics at the end of each epoch
        train_acc = train_acc_metric.result()
        print("Training acc over epoch: %.4f" % (float(train_acc),))

        val_acc = val_acc_metric.result()
        print("Validation acc: %.4f" % (float(val_acc),))

        # Reset metrics at the end of each epoch
        # reset_state() is the current name; older TF versions also accepted reset_states()
        train_acc_metric.reset_state()
        val_acc_metric.reset_state()

        # 3️⃣ log metrics using wandb.log
        wandb.log(
            {
                "epochs": epoch,
                "loss": np.mean(train_loss),
                "acc": float(train_acc),
                "val_loss": np.mean(val_loss),
                "val_acc": float(val_acc),
            }
        )

Configure the Sweep

This is where you will:

  • Define the hyperparameters you’re sweeping over
  • Choose a hyperparameter optimization method: random, grid, or bayes.
  • If you use bayes, provide a metric and a goal, for example minimizing val_loss.
  • Use hyperband for early termination of poorly performing runs.
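The sketch below makes the difference between the random and grid methods concrete, using the batch_size and learning_rate values this tutorial sweeps over. Grid search enumerates all 4 × 5 = 20 combinations. Random search draws independent samples, so a run budget such as count=10 covers only part of the space. The uniform sampling here is an assumption for illustration and is not the W&B sweep controller's code.

```python
import itertools
import random

# The batch_size and learning_rate values swept over in this tutorial.
parameters = {
    "batch_size": [32, 64, 128, 256],
    "learning_rate": [0.01, 0.005, 0.001, 0.0005, 0.0001],
}

# Grid search: every combination of values (4 x 5 = 20 configurations).
grid = [
    dict(zip(parameters, combo))
    for combo in itertools.product(*parameters.values())
]

# Random search (sketch): pick each value uniformly and independently per trial.
rng = random.Random(0)  # fixed seed so the sketch is reproducible
random_trials = [
    {name: rng.choice(values) for name, values in parameters.items()}
    for _ in range(10)
]
```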

Check out more on Sweep Configs

sweep_config = {
    "method": "random",
    "metric": {"name": "val_loss", "goal": "minimize"},
    "early_terminate": {"type": "hyperband", "min_iter": 5},
    "parameters": {
        "batch_size": {"values": [32, 64, 128, 256]},
        "learning_rate": {"values": [0.01, 0.005, 0.001, 0.0005, 0.0001]},
    },
}
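Hyperband stops weak runs only at fixed checkpoints, called brackets. W&B's sweep documentation describes the first bracket at min_iter, with later brackets growing by a factor eta (3 by default). The sketch below is an illustration rather than W&B's controller code. It lists those brackets and shows that with only two logged epochs per run, the first bracket at 5 is never reached, so early termination is unlikely to trigger in this short demo.

```python
def hyperband_brackets(min_iter, eta=3, num_brackets=4):
    """Iterations at which runs would be compared: min_iter * eta**k (sketch)."""
    return [min_iter * eta**k for k in range(num_brackets)]


# min_iter=5 comes from the sweep configuration above; num_brackets=4 is arbitrary.
brackets = hyperband_brackets(min_iter=5)  # [5, 15, 45, 135]

# Assumption: one "iteration" is one logged value of the sweep metric. The train
# loop logs val_loss once per epoch, and sweep_train sets epochs = 2.
logged_values_per_run = 2
reached = [b for b in brackets if b <= logged_values_per_run]  # none reached
```

To see hyperband actually stop runs, you would need more epochs per run than min_iter, or a smaller min_iter.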

Wrap the Training Loop

You’ll need a function, like sweep_train below, that uses wandb.config to set the hyperparameters before train gets called.

def sweep_train(config_defaults=None):
    # Set default values (used only when no defaults are passed in)
    if config_defaults is None:
        config_defaults = {"batch_size": 64, "learning_rate": 0.01}
    # Initialize wandb; the values chosen by the sweep override these defaults
    wandb.init(config=config_defaults)

    # Specify the other hyperparameters to the configuration, if any
    wandb.config.epochs = 2
    wandb.config.log_step = 20
    wandb.config.val_log_step = 50
    wandb.config.architecture_name = "MLP"
    wandb.config.dataset_name = "MNIST"

    # build input pipeline using tf.data
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = (
        train_dataset.shuffle(buffer_size=1024)
        .batch(wandb.config.batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

    val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    val_dataset = val_dataset.batch(wandb.config.batch_size).prefetch(
        buffer_size=tf.data.AUTOTUNE
    )

    # initialize model
    model = Model()

    # Instantiate an optimizer to train the model.
    optimizer = keras.optimizers.SGD(learning_rate=wandb.config.learning_rate)
    # Instantiate a loss function.
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # Prepare the metrics.
    train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
    val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

    train(
        train_dataset,
        val_dataset,
        model,
        optimizer,
        loss_fn,
        train_acc_metric,
        val_acc_metric,
        epochs=wandb.config.epochs,
        log_step=wandb.config.log_step,
        val_log_step=wandb.config.val_log_step,
    )
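Inside sweep_train, the values picked by the sweep override the defaults passed to wandb.init, and keys the sweep does not set keep their defaults. Plain Python dictionaries can illustrate this precedence. The sampled values below are made up, and this is only an analogy, not how wandb.config is implemented.

```python
# Defaults passed to wandb.init inside sweep_train.
config_defaults = {"batch_size": 64, "learning_rate": 0.01}
# Hypothetical values that the sweep controller might pick for one run.
sweep_choice = {"batch_size": 128, "learning_rate": 0.0005}

# Later entries win, so sweep values take precedence over the defaults.
effective = {**config_defaults, **sweep_choice}
# Extra keys set on the config afterwards are simply added.
effective.update({"epochs": 2, "log_step": 20, "val_log_step": 50})
```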

Initialize Sweep and Run Agent

sweep_id = wandb.sweep(sweep_config, project="sweeps-tensorflow")

You can limit the total number of runs with the count parameter. This example limits it to 10 so the script runs quickly. Feel free to increase the number of runs and see what happens.

wandb.agent(sweep_id, function=sweep_train, count=10)

Visualize Results

Click on the Sweep URL link above to see your live results.

See examples of projects tracked and visualized with W&B in the Gallery →

Best Practices

  1. Projects: Log multiple runs to a project to compare them. wandb.init(project="project-name")
  2. Groups: For multiple processes or cross-validation folds, log each process as a run and group them together. wandb.init(group='experiment-1')
  3. Tags: Add tags to track your current baseline or production model.
  4. Notes: Type notes in the table to track the changes between runs.
  5. Reports: Take quick notes on progress to share with colleagues and make dashboards and snapshots of your ML projects.

Advanced Setup

  1. Environment variables: Set API keys in environment variables so you can run training on a managed cluster.
  2. Offline mode
  3. On-prem: Install W&B in a private cloud or on air-gapped servers in your own infrastructure. Everyone from academics to enterprise teams uses local installations.

6.6 - 3D brain tumor segmentation with MONAI

This tutorial demonstrates how to build a training workflow for a multi-label 3D brain tumor segmentation task using MONAI, and how to use the experiment tracking and data visualization features of Weights & Biases. The tutorial covers the following:

  1. Initialize a Weights & Biases run and synchronize all configs associated with the run for reproducibility.
  2. MONAI transform API:
    1. MONAI Transforms for dictionary format data.
    2. How to define a new transform according to MONAI transforms API.
    3. How to randomly adjust intensity for data augmentation.
  3. Data Loading and Visualization:
    1. Load Nifti image with metadata, load a list of images and stack them.
    2. Cache IO and transforms to accelerate training and validation.
    3. Visualize the data using wandb.Table and interactive segmentation overlay on Weights & Biases.
  4. Training a 3D SegResNet model
    1. Using the networks, losses, and metrics APIs from MONAI.
    2. Training the 3D SegResNet model using a PyTorch training loop.
    3. Track the training experiment using Weights & Biases.
    4. Log and version model checkpoints as model artifacts on Weights & Biases.
  5. Visualize and compare the predictions on the validation dataset using wandb.Table and interactive segmentation overlay on Weights & Biases.

Setup and Installation

First, install the latest version of both MONAI and Weights and Biases.

!python -c "import monai" || pip install -q -U "monai[nibabel, tqdm]"
!python -c "import wandb" || pip install -q -U wandb
import os

import numpy as np
from tqdm.auto import tqdm
import wandb

from monai.apps import DecathlonDataset
from monai.data import DataLoader, decollate_batch
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism

import torch

Then, authenticate the Colab instance to use W&B.

wandb.login()

Initialize a W&B Run

Start a new W&B run to begin tracking the experiment.

wandb.init(project="monai-brain-tumor-segmentation")

Using a proper config system is a recommended best practice for reproducible machine learning. You can track the hyperparameters for every experiment using W&B.

config = wandb.config
config.seed = 0
config.roi_size = [224, 224, 144]
config.batch_size = 1
config.num_workers = 4
config.max_train_images_visualized = 20
config.max_val_images_visualized = 20
config.dice_loss_smoothen_numerator = 0
config.dice_loss_smoothen_denominator = 1e-5
config.dice_loss_squared_prediction = True
config.dice_loss_target_onehot = False
config.dice_loss_apply_sigmoid = True
config.initial_learning_rate = 1e-4
config.weight_decay = 1e-5
config.max_train_epochs = 50
config.validation_intervals = 1
config.dataset_dir = "./dataset/"
config.checkpoint_dir = "./checkpoints"
config.inference_roi_size = (128, 128, 64)
config.max_prediction_images_visualized = 20

You also need to set the random seed for the modules. set_determinism enables deterministic training with the given seed; passing seed=None turns it off.

set_determinism(seed=config.seed)

# Create directories
os.makedirs(config.dataset_dir, exist_ok=True)
os.makedirs(config.checkpoint_dir, exist_ok=True)

Data Loading and Transformation

Here, use the monai.transforms API to create a custom transform that converts the multi-class labels into a one-hot, multi-label format for the segmentation task.

class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # merge label 2 and label 3 to construct TC
            result.append(torch.logical_or(d[key] == 2, d[key] == 3))
            # merge labels 1, 2 and 3 to construct WT
            result.append(
                torch.logical_or(
                    torch.logical_or(d[key] == 2, d[key] == 3), d[key] == 1
                )
            )
            # label 2 is ET
            result.append(d[key] == 2)
            d[key] = torch.stack(result, axis=0).float()
        return d
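The same TC / WT / ET conversion can be shown on a tiny 2×2 label map, which makes the overlapping channels easy to see. This NumPy sketch mirrors the logic of the transform above. brats_to_multichannel is a hypothetical helper and not part of MONAI.

```python
import numpy as np


def brats_to_multichannel(label):
    """NumPy sketch of the TC / WT / ET conversion performed by the transform."""
    tc = np.logical_or(label == 2, label == 3)  # tumor core: labels 2 and 3
    wt = np.logical_or(tc, label == 1)  # whole tumor: labels 1, 2 and 3
    et = label == 2  # enhancing tumor: label 2 only
    return np.stack([tc, wt, et], axis=0).astype(np.float32)


label = np.array([[0, 1], [2, 3]])  # tiny 2x2 label map
channels = brats_to_multichannel(label)  # shape (3, 2, 2)
# channels[0] (TC) -> [[0, 0], [1, 1]]
# channels[1] (WT) -> [[0, 1], [1, 1]]
# channels[2] (ET) -> [[0, 0], [1, 0]]
```

The channels are nested (ET ⊂ TC ⊂ WT), so a voxel can be positive in several channels at once. This is why the task is multi-label and the loss below uses a sigmoid rather than a softmax.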

Next, set up transforms for training and validation datasets respectively.

train_transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        RandSpatialCropd(
            keys=["image", "label"], roi_size=config.roi_size, random_size=False
        ),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

The Dataset

The dataset used for this experiment comes from http://medicaldecathlon.com/. It uses multi-modal multi-site MRI data (FLAIR, T1w, T1gd, T2w) to segment Gliomas, necrotic/active tumour, and oedema. The dataset consists of 750 4D volumes (484 Training + 266 Testing).

Use the DecathlonDataset to automatically download and extract the dataset. It inherits from MONAI’s CacheDataset, which lets you set cache_num=N to cache N items for training and use the default arguments to cache all items for validation, depending on your memory size.

train_dataset = DecathlonDataset(
    root_dir=config.dataset_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    section="training",
    download=True,
    cache_rate=0.0,
    num_workers=4,
)
val_dataset = DecathlonDataset(
    root_dir=config.dataset_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=4,
)
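MONAI's CacheDataset documentation states that the number of cached items is the minimum of cache_num, data_length × cache_rate, and data_length. The helper below is a hypothetical sketch of that rule. It shows that the cache_rate=0.0 used here caches nothing, which saves memory but means the transforms run again every time an item is loaded.

```python
import sys


def num_cached_items(data_length, cache_num=sys.maxsize, cache_rate=1.0):
    """Items a MONAI CacheDataset keeps in memory, per its documented rule."""
    return min(cache_num, int(data_length * cache_rate), data_length)


no_cache = num_cached_items(100, cache_rate=0.0)  # 0, as in this tutorial
partial = num_cached_items(100, cache_num=24)  # 24: capped by cache_num
everything = num_cached_items(100)  # 100: the defaults cache all items
```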

Visualizing the Dataset

Weights & Biases supports images, video, audio, and more. You can log rich media to explore your results and visually compare your runs, models, and datasets. Use the segmentation mask overlay system to visualize your data volumes. To log segmentation masks in tables, you must provide a wandb.Image object for each row in the table.

An example is provided in the pseudocode below:

table = wandb.Table(columns=["ID", "Image"])

for id_, img, label in zip(ids, images, labels):
    mask_img = wandb.Image(
        img,
        masks={
            "prediction": {"mask_data": label, "class_labels": class_labels}
            # ...
        },
    )

    # Add the wandb.Image with its mask overlay (not the raw image) to the row
    table.add_data(id_, mask_img)

wandb.log({"Table": table})

Now write a simple utility function that takes a sample image, its label, a wandb.Table object, and some associated metadata, and populates the rows of a table that is logged to the Weights & Biases dashboard.

def log_data_samples_into_tables(
    sample_image: np.ndarray,
    sample_label: np.ndarray,
    split: str = None,
    data_idx: int = None,
    table: wandb.Table = None,
):
    num_channels, _, _, num_slices = sample_image.shape
    with tqdm(total=num_slices, leave=False) as progress_bar:
        for slice_idx in range(num_slices):
            # Ground-truth masks for this slice; each class gets a distinct ID
            masks = {
                "ground-truth/Tumor-Core": {
                    "mask_data": sample_label[0, :, :, slice_idx],
                    "class_labels": {0: "background", 1: "Tumor Core"},
                },
                "ground-truth/Whole-Tumor": {
                    "mask_data": sample_label[1, :, :, slice_idx] * 2,
                    "class_labels": {0: "background", 2: "Whole Tumor"},
                },
                "ground-truth/Enhancing-Tumor": {
                    "mask_data": sample_label[2, :, :, slice_idx] * 3,
                    "class_labels": {0: "background", 3: "Enhancing Tumor"},
                },
            }
            # Overlay the same masks on every image channel (MRI modality)
            ground_truth_wandb_images = []
            for channel_idx in range(num_channels):
                ground_truth_wandb_images.append(
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks=masks,
                    )
                )
            # One table row per slice: metadata followed by one image per channel
            table.add_data(split, data_idx, slice_idx, *ground_truth_wandb_images)
            progress_bar.update(1)
    return table
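In the function above, each binary label channel is multiplied by its own class ID (1, 2, or 3). This keeps the values in each mask layer aligned with the keys of that layer's class_labels dictionary, and each volume contributes one table row per slice. The NumPy sketch below checks both points on a small synthetic label volume. The random data is made up and only stands in for sample_label.

```python
import numpy as np

# Synthetic one-hot label volume shaped (3 channels, H, W, slices), standing in
# for sample_label: channel 0 = TC, channel 1 = WT, channel 2 = ET.
rng = np.random.default_rng(0)
sample_label = (rng.random((3, 4, 4, 2)) > 0.5).astype(np.float32)

slice_idx = 0
mask_layers = {
    "Tumor-Core": sample_label[0, :, :, slice_idx] * 1,
    "Whole-Tumor": sample_label[1, :, :, slice_idx] * 2,
    "Enhancing-Tumor": sample_label[2, :, :, slice_idx] * 3,
}
# Each layer holds only 0 (background) and its own class ID.
layer_values = {name: set(np.unique(m).tolist()) for name, m in mask_layers.items()}

rows_per_volume = sample_label.shape[-1]  # one table row per slice
```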

Next, define the wandb.Table object and its columns so that it can be populated with the data visualizations.

table = wandb.Table(
    columns=[
        "Split",
        "Data Index",
        "Slice Index",
        "Image-Channel-0",
        "Image-Channel-1",
        "Image-Channel-2",
        "Image-Channel-3",
    ]
)

Then, loop over train_dataset and val_dataset to generate visualizations for the data samples and populate the rows of the table that you log to the dashboard.

# Generate visualizations for train_dataset
max_samples = (
    min(config.max_train_images_visualized, len(train_dataset))
    if config.max_train_images_visualized > 0
    else len(train_dataset)
)
progress_bar = tqdm(
    enumerate(train_dataset[:max_samples]),
    total=max_samples,
    desc="Generating Train Dataset Visualizations:",
)
for data_idx, sample in progress_bar:
    sample_image = sample["image"].detach().cpu().numpy()
    sample_label = sample["label"].detach().cpu().numpy()
    table = log_data_samples_into_tables(
        sample_image,
        sample_label,
        split="train",
        data_idx=data_idx,
        table=table,
    )

# Generate visualizations for val_dataset
max_samples = (
    min(config.max_val_images_visualized, len(val_dataset))
    if config.max_val_images_visualized > 0
    else len(val_dataset)
)
progress_bar = tqdm(
    enumerate(val_dataset[:max_samples]),
    total=max_samples,
    desc="Generating Validation Dataset Visualizations:",
)
for data_idx, sample in progress_bar:
    sample_image = sample["image"].detach().cpu().numpy()
    sample_label = sample["label"].detach().cpu().numpy()
    table = log_data_samples_into_tables(
        sample_image,
        sample_label,
        split="val",
        data_idx=data_idx,
        table=table,
    )

# Log the table to your dashboard
wandb.log({"Tumor-Segmentation-Data": table})

The data appears on the W&B dashboard in an interactive tabular format. We can see each channel of a particular slice from a data volume overlaid with the respective segmentation mask in each row. You can write Weave queries to filter the data on the table and focus on one particular row.

An example of logged table data.

Open an image and see how you can interact with each of the segmentation masks using the interactive overlay.

An example of visualized segmentation maps.

Loading the Data

Create the PyTorch DataLoaders for loading the data from the datasets. Before creating the DataLoaders, set the transform for train_dataset to train_transform to pre-process and transform the data for training.

# apply train_transform to the training dataset
train_dataset.transform = train_transform

# create the train_loader
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
)

# create the val_loader
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
)

Creating the Model, Loss, and Optimizer

This tutorial creates a SegResNet model based on the paper 3D MRI brain tumor segmentation using auto-encoder regularization. SegResNet is implemented as a PyTorch Module in the monai.networks API. The code below also creates an optimizer and a learning rate scheduler.

device = torch.device("cuda:0")

# create model
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to(device)

# create optimizer
optimizer = torch.optim.Adam(
    model.parameters(),
    config.initial_learning_rate,
    weight_decay=config.weight_decay,
)

# create learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config.max_train_epochs
)
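Because lr_scheduler.step() is called once per epoch in the training loop, CosineAnnealingLR with T_max=config.max_train_epochs decays the learning rate from config.initial_learning_rate toward its default minimum of zero along half a cosine period. A sketch of the closed-form schedule given in the PyTorch documentation; the function name and numeric values are hypothetical, not the tutorial's config:

```python
import math


def cosine_annealing_lr(base_lr: float, epoch: int, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing learning rate (no warm restarts)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# Hypothetical values: a base learning rate of 1e-4 annealed over 50 epochs.
base_lr, t_max = 1e-4, 50
for epoch in (0, 25, 50):
    # Full rate at the start, half at the midpoint, eta_min at the end.
    print(epoch, cosine_annealing_lr(base_lr, epoch, t_max))
```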

Define the loss as multi-label DiceLoss using the monai.losses API and the corresponding dice metrics using the monai.metrics API.

loss_function = DiceLoss(
    smooth_nr=config.dice_loss_smoothen_numerator,
    smooth_dr=config.dice_loss_smoothen_denominator,
    squared_pred=config.dice_loss_squared_prediction,
    to_onehot_y=config.dice_loss_target_onehot,
    sigmoid=config.dice_loss_apply_sigmoid,
)

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# use automatic mixed-precision to accelerate training
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
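For reference, the soft Dice loss configured above compares per-channel sigmoid probabilities with the binary target masks, and post_trans turns logits into binary masks before the Dice metric is computed. A NumPy sketch of both, assuming sigmoid=True and a target that is already multi-channel (not one-hot encoded by the loss); the function names and toy arrays are hypothetical, and MONAI's own implementation additionally handles batching, reduction modes, and other options:

```python
import numpy as np


def soft_dice_loss(logits, target, smooth_nr=1e-5, smooth_dr=1e-5, squared_pred=True):
    """Multi-label soft Dice loss for arrays shaped (channels, *spatial). Illustration only."""
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid: one probability per voxel and channel
    axes = tuple(range(1, probs.ndim))     # reduce over spatial axes, keep channels
    intersection = np.sum(probs * target, axis=axes)
    if squared_pred:
        denominator = np.sum(target ** 2, axis=axes) + np.sum(probs ** 2, axis=axes)
    else:
        denominator = np.sum(target, axis=axes) + np.sum(probs, axis=axes)
    dice = (2.0 * intersection + smooth_nr) / (denominator + smooth_dr)
    return float(np.mean(1.0 - dice))      # average the per-channel losses


def post_transform(logits, threshold=0.5):
    """Sigmoid followed by thresholding, mirroring Activations(sigmoid=True) + AsDiscrete(threshold=0.5)."""
    return (1.0 / (1.0 + np.exp(-logits)) >= threshold).astype(np.float32)


# Toy 3-channel, 2x2 "slice" (hypothetical values).
target = np.array(
    [[[1, 0], [0, 0]], [[1, 1], [0, 0]], [[0, 0], [0, 0]]], dtype=np.float32
)
confident = np.where(target > 0, 10.0, -10.0)  # logits that agree with the target
print(soft_dice_loss(confident, target))       # prints a value close to 0
print(post_transform(confident))
```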

Define a small utility for mixed-precision inference. This will be useful during the validation step of the training process and when you want to run the model after training.

def inference(model, input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    with torch.cuda.amp.autocast():
        return _compute(input)
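sliding_window_inference splits a volume into windows of roi_size, runs the model on each window, and blends the outputs where windows overlap (overlap=0.5 means neighboring windows share half their extent). A simplified one-dimensional sketch of that idea with uniform averaging in the overlaps; the helper names are hypothetical, and MONAI's implementation additionally handles padding, batching windows (sw_batch_size), and Gaussian blending (mode="gaussian"):

```python
import numpy as np


def window_starts(length: int, roi: int, overlap: float) -> list:
    """Start offsets of windows of size `roi` covering `length` with the given overlap."""
    if roi >= length:
        return [0]  # one window covers the whole axis (real implementations pad instead)
    stride = max(int(roi * (1 - overlap)), 1)
    starts = list(range(0, length - roi + 1, stride))
    if starts[-1] != length - roi:
        starts.append(length - roi)  # shift the last window so it ends at the border
    return starts


def sliding_window_1d(signal: np.ndarray, roi: int, overlap: float, predictor) -> np.ndarray:
    """Run `predictor` on overlapping windows and average the overlapping outputs."""
    output = np.zeros(signal.shape, dtype=np.float64)
    counts = np.zeros(signal.shape, dtype=np.float64)
    for start in window_starts(len(signal), roi, overlap):
        window = signal[start:start + roi]
        output[start:start + roi] += predictor(window)
        counts[start:start + roi] += 1  # how many windows covered each position
    return output / counts


signal = np.arange(10, dtype=np.float64)
print(window_starts(10, 4, 0.5))  # [0, 2, 4, 6]
result = sliding_window_1d(signal, roi=4, overlap=0.5, predictor=lambda w: w * 2)
print(np.allclose(result, signal * 2))  # True: averaging identical per-window outputs is lossless
```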

Training and Validation

Before training, define the metric properties which will later be logged with wandb.log() for tracking the training and validation experiments.

wandb.define_metric("epoch/epoch_step")
wandb.define_metric("epoch/*", step_metric="epoch/epoch_step")
wandb.define_metric("batch/batch_step")
wandb.define_metric("batch/*", step_metric="batch/batch_step")
wandb.define_metric("validation/validation_step")
wandb.define_metric("validation/*", step_metric="validation/validation_step")

batch_step = 0
validation_step = 0
metric_values = []
metric_values_tumor_core = []
metric_values_whole_tumor = []
metric_values_enhanced_tumor = []
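The glob patterns route each logged key to its own x-axis: every key under epoch/ is plotted against epoch/epoch_step, and so on, which is why each wandb.log call in the loop below includes the matching step key. A plain-Python illustration of that routing, using fnmatchcase as a stand-in for W&B's own pattern matching (an assumption, not the W&B implementation); the rule list and helper are hypothetical:

```python
from fnmatch import fnmatchcase

# Hypothetical rules mirroring the define_metric calls above: (glob pattern, step metric).
STEP_RULES = [
    ("epoch/*", "epoch/epoch_step"),
    ("batch/*", "batch/batch_step"),
    ("validation/*", "validation/validation_step"),
]


def step_metric_for(key: str, rules=STEP_RULES):
    """Return the custom x-axis for a logged key, or None to fall back to the default step."""
    for pattern, step_metric in rules:
        # The step metric itself is not plotted against itself.
        if key != step_metric and fnmatchcase(key, pattern):
            return step_metric
    return None


print(step_metric_for("batch/train_loss"))         # batch/batch_step
print(step_metric_for("validation/mean_dice"))     # validation/validation_step
print(step_metric_for("Tumor-Segmentation-Data"))  # None
```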

Execute Standard PyTorch Training Loop

epoch_progress_bar = tqdm(range(config.max_train_epochs), desc="Training:")

for epoch in epoch_progress_bar:
    model.train()
    epoch_loss = 0

    # len(train_loader) also counts the final, smaller batch
    total_batch_steps = len(train_loader)
    batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)

    # Training Step
    for batch_data in batch_progress_bar:
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        batch_progress_bar.set_description(f"train_loss: {loss.item():.4f}:")
        ## Log batch-wise training loss to W&B
        wandb.log({"batch/batch_step": batch_step, "batch/train_loss": loss.item()})
        batch_step += 1

    lr_scheduler.step()
    epoch_loss /= total_batch_steps
    ## Log epoch-wise mean training loss and learning rate to W&B
    wandb.log(
        {
            "epoch/epoch_step": epoch,
            "epoch/mean_train_loss": epoch_loss,
            "epoch/learning_rate": lr_scheduler.get_last_lr()[0],
        }
    )
    epoch_progress_bar.set_description(f"Training: train_loss: {epoch_loss:.4f}:")

    # Validation and model checkpointing step
    if (epoch + 1) % config.validation_intervals == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_outputs = inference(model, val_inputs)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)

            metric_values.append(dice_metric.aggregate().item())
            metric_batch = dice_metric_batch.aggregate()
            metric_values_tumor_core.append(metric_batch[0].item())
            metric_values_whole_tumor.append(metric_batch[1].item())
            metric_values_enhanced_tumor.append(metric_batch[2].item())
            dice_metric.reset()
            dice_metric_batch.reset()

            checkpoint_path = os.path.join(config.checkpoint_dir, "model.pth")
            torch.save(model.state_dict(), checkpoint_path)

            # Log and version model checkpoints using W&B artifacts. A logged
            # artifact is finalized and cannot be modified, so create a new
            # Artifact object for every checkpoint; each one is logged as a
            # version of the same artifact name.
            artifact = wandb.Artifact(
                name=f"{wandb.run.id}-checkpoint", type="model"
            )
            artifact.add_file(local_path=checkpoint_path)
            wandb.log_artifact(artifact, aliases=[f"epoch_{epoch}"])

            # Log validation metrics to W&B dashboard.
            wandb.log(
                {
                    "validation/validation_step": validation_step,
                    "validation/mean_dice": metric_values[-1],
                    "validation/mean_dice_tumor_core": metric_values_tumor_core[-1],
                    "validation/mean_dice_whole_tumor": metric_values_whole_tumor[-1],
                    "validation/mean_dice_enhanced_tumor": metric_values_enhanced_tumor[-1],
                }
            )
            validation_step += 1


# Wait for the last checkpoint artifact to finish logging
artifact.wait()

Instrumenting the code with wandb.log tracks all metrics associated with the training and validation process, and the W&B dashboard also shows the system metrics (CPU and GPU in this case) collected during the run.

An example of training and validation process tracking on W&B.

Navigate to the artifacts tab in the W&B run dashboard to access the different versions of model checkpoint artifacts logged during training.

An example of model checkpoints logging and versioning on W&B.

Inference

Using the artifacts interface, you can select which version of the artifact is the best model checkpoint, in this case the version with the best mean epoch-wise training loss. You can also explore the entire lineage of the artifact and use the version that you need.

An example of model artifact tracking on W&B.

Fetch the version of the model artifact with the best epoch-wise mean training loss and load the checkpoint state dictionary to the model.

model_artifact = wandb.use_artifact(
    "geekyrakshit/monai-brain-tumor-segmentation/d5ex6n4a-checkpoint:v49",
    type="model",
)
model_artifact_dir = model_artifact.download()
model.load_state_dict(torch.load(os.path.join(model_artifact_dir, "model.pth")))
model.eval()

Visualizing Predictions and Comparing with the Ground Truth Labels

Create another utility function to visualize the predictions of the pre-trained model and compare them with the corresponding ground-truth segmentation masks using the interactive segmentation mask overlay.

def log_predictions_into_tables(
    sample_image: np.ndarray,
    sample_label: np.ndarray,
    predicted_label: np.ndarray,
    split: str = None,
    data_idx: int = None,
    table: wandb.Table = None,
):
    num_channels, _, _, num_slices = sample_image.shape
    with tqdm(total=num_slices, leave=False) as progress_bar:
        for slice_idx in range(num_slices):
            wandb_images = []
            for channel_idx in range(num_channels):
                wandb_images += [
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks={
                            "ground-truth/Tumor-Core": {
                                "mask_data": sample_label[0, :, :, slice_idx],
                                "class_labels": {0: "background", 1: "Tumor Core"},
                            },
                            "prediction/Tumor-Core": {
                                "mask_data": predicted_label[0, :, :, slice_idx] * 2,
                                "class_labels": {0: "background", 2: "Tumor Core"},
                            },
                        },
                    ),
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks={
                            "ground-truth/Whole-Tumor": {
                                "mask_data": sample_label[1, :, :, slice_idx],
                                "class_labels": {0: "background", 1: "Whole Tumor"},
                            },
                            "prediction/Whole-Tumor": {
                                "mask_data": predicted_label[1, :, :, slice_idx] * 2,
                                "class_labels": {0: "background", 2: "Whole Tumor"},
                            },
                        },
                    ),
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks={
                            "ground-truth/Enhancing-Tumor": {
                                "mask_data": sample_label[2, :, :, slice_idx],
                                "class_labels": {0: "background", 1: "Enhancing Tumor"},
                            },
                            "prediction/Enhancing-Tumor": {
                                "mask_data": predicted_label[2, :, :, slice_idx] * 2,
                                "class_labels": {0: "background", 2: "Enhancing Tumor"},
                            },
                        },
                    ),
                ]
            table.add_data(split, data_idx, slice_idx, *wandb_images)
            progress_bar.update(1)
    return table
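table.add_data matches values to columns by position, and log_predictions_into_tables appends images channel by channel, with the three tumor classes innermost. A small sketch, using a hypothetical helper, that generates column names in that same order so a hand-written column list does not drift out of sync with the loop:

```python
def table_columns(num_channels=4, classes=("Tumor-Core", "Whole-Tumor", "Enhancing-Tumor")):
    """Column names in the order log_predictions_into_tables appends images:
    the outer loop runs over image channels, the inner list over tumor classes."""
    columns = ["Split", "Data Index", "Slice Index"]
    for channel_idx in range(num_channels):
        for tumor_class in classes:
            columns.append(f"Image-Channel-{channel_idx}/{tumor_class}")
    return columns


print(table_columns()[3:6])
# ['Image-Channel-0/Tumor-Core', 'Image-Channel-0/Whole-Tumor', 'Image-Channel-0/Enhancing-Tumor']
```

The result could then be passed as wandb.Table(columns=table_columns()).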

Log the prediction results to the prediction table.

# create the prediction table
prediction_table = wandb.Table(
    # Image columns follow the order in which log_predictions_into_tables
    # appends images: channel by channel, with the three tumor classes innermost.
    columns=[
        "Split",
        "Data Index",
        "Slice Index",
        "Image-Channel-0/Tumor-Core",
        "Image-Channel-0/Whole-Tumor",
        "Image-Channel-0/Enhancing-Tumor",
        "Image-Channel-1/Tumor-Core",
        "Image-Channel-1/Whole-Tumor",
        "Image-Channel-1/Enhancing-Tumor",
        "Image-Channel-2/Tumor-Core",
        "Image-Channel-2/Whole-Tumor",
        "Image-Channel-2/Enhancing-Tumor",
        "Image-Channel-3/Tumor-Core",
        "Image-Channel-3/Whole-Tumor",
        "Image-Channel-3/Enhancing-Tumor",
    ]
)

# Perform inference and visualization
with torch.no_grad():
    max_samples = (
        min(config.max_prediction_images_visualized, len(val_dataset))
        if config.max_prediction_images_visualized > 0
        else len(val_dataset)
    )
    progress_bar = tqdm(
        enumerate(val_dataset[:max_samples]),
        total=max_samples,
        desc="Generating Predictions:",
    )
    for data_idx, sample in progress_bar:
        val_input = sample["image"].unsqueeze(0).to(device)
        val_output = inference(model, val_input)
        val_output = post_trans(val_output[0])
        prediction_table = log_predictions_into_tables(
            sample_image=sample["image"].cpu().numpy(),
            sample_label=sample["label"].cpu().numpy(),
            predicted_label=val_output.cpu().numpy(),
            data_idx=data_idx,
            split="validation",
            table=prediction_table,
        )

    wandb.log({"Predictions/Tumor-Segmentation-Data": prediction_table})


# End the experiment
wandb.finish()

Use the interactive segmentation mask overlay to analyze and compare the predicted segmentation masks and the ground-truth labels for each class.

An example of predictions and ground-truth visualization on W&B.

Acknowledgements and more resources

6.7 - Keras

Use Weights & Biases for machine learning experiment tracking, dataset versioning, and project collaboration.

This Colab notebook introduces the WandbMetricsLogger callback. Use this callback for experiment tracking: it logs your training and validation metrics to Weights & Biases, while the W&B run also records system metrics.

Setup and Installation

First, let us install the latest version of Weights & Biases. We will then authenticate this Colab instance to use W&B.

!pip install -qq -U wandb
import os
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
import tensorflow_datasets as tfds

# Weights and Biases related imports
import wandb
from wandb.integration.keras import WandbMetricsLogger

If this is your first time using W&B or you are not logged in, the link that appears after running wandb.login() takes you to the sign-up/login page. Signing up for a free account takes only a few clicks.

wandb.login()

Hyperparameters

Using a proper config system is a recommended best practice for reproducible machine learning. W&B can track the hyperparameters of every experiment. In this Colab, we use a simple Python dict as our config system.

configs = dict(
    num_classes=10,
    shuffle_buffer=1024,
    batch_size=64,
    image_size=28,
    image_channels=1,
    earlystopping_patience=3,
    learning_rate=1e-3,
    epochs=10,
)

Dataset

In this Colab, we use the Fashion-MNIST dataset from the TensorFlow Datasets catalog. We aim to build a simple image classification pipeline using TensorFlow/Keras.

train_ds, valid_ds = tfds.load("fashion_mnist", split=["train", "test"])
AUTOTUNE = tf.data.AUTOTUNE


def parse_data(example):
    # Get image
    image = example["image"]
    # image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # Get label
    label = example["label"]
    label = tf.one_hot(label, depth=configs["num_classes"])

    return image, label


def get_dataloader(ds, configs, dataloader_type="train"):
    dataloader = ds.map(parse_data, num_parallel_calls=AUTOTUNE)

    if dataloader_type == "train":
        dataloader = dataloader.shuffle(configs["shuffle_buffer"])

    dataloader = dataloader.batch(configs["batch_size"]).prefetch(AUTOTUNE)

    return dataloader
trainloader = get_dataloader(train_ds, configs)
validloader = get_dataloader(valid_ds, configs, dataloader_type="valid")

Model

def get_model(configs):
    backbone = tf.keras.applications.mobilenet_v2.MobileNetV2(
        weights="imagenet", include_top=False
    )
    backbone.trainable = False

    inputs = layers.Input(
        shape=(configs["image_size"], configs["image_size"], configs["image_channels"])
    )
    resize = layers.Resizing(32, 32)(inputs)
    neck = layers.Conv2D(3, (3, 3), padding="same")(resize)
    preprocess_input = tf.keras.applications.mobilenet.preprocess_input(neck)
    x = backbone(preprocess_input)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(configs["num_classes"], activation="softmax")(x)

    return models.Model(inputs=inputs, outputs=outputs)
tf.keras.backend.clear_session()
model = get_model(configs)
model.summary()
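Because backbone.trainable = False, only the Conv2D neck and the Dense head are trained. A quick arithmetic check of the trainable-parameter count that model.summary() would be expected to report, assuming the default MobileNetV2 width (alpha=1.0), whose final feature map has 1280 channels:

```python
# Trainable-parameter arithmetic for get_model(configs), assuming the default
# MobileNetV2 (alpha=1.0) with 1280 output feature channels.
in_channels = 1           # configs["image_channels"]
neck_filters = 3          # Conv2D(3, (3, 3))
kernel_h, kernel_w = 3, 3
backbone_features = 1280  # channels left after GlobalAveragePooling2D
num_classes = 10          # configs["num_classes"]

# Conv2D: one 3x3 kernel per (input channel, filter) pair, plus one bias per filter.
neck_params = kernel_h * kernel_w * in_channels * neck_filters + neck_filters
# Dense: one weight per (feature, class) pair, plus one bias per class.
head_params = backbone_features * num_classes + num_classes

# The frozen backbone contributes no trainable parameters.
trainable_params = neck_params + head_params
print(neck_params, head_params, trainable_params)  # 30 12810 12840
```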

Compile Model

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=[
        "accuracy",
        tf.keras.metrics.TopKCategoricalAccuracy(k=5, name="top@5_accuracy"),
    ],
)
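TopKCategoricalAccuracy(k=5) counts a prediction as correct when the true class is among the five highest-scoring classes, so with 10 classes it is a much looser metric than plain accuracy. A NumPy sketch of the computation; the function name and toy arrays are hypothetical, and Keras handles batching and ties internally:

```python
import numpy as np


def top_k_categorical_accuracy(y_true_onehot, y_pred, k=5):
    """Fraction of samples whose true class is among the k highest-scoring predictions."""
    true_classes = np.argmax(y_true_onehot, axis=-1)
    # Indices of the k largest scores per row (their order within the top k does not matter).
    top_k = np.argsort(y_pred, axis=-1)[:, -k:]
    hits = np.any(top_k == true_classes[:, None], axis=-1)
    return float(np.mean(hits))


# Toy example: 2 samples, 4 classes (hypothetical scores).
y_true = np.eye(4)[[0, 3]]
y_pred = np.array([[0.5, 0.3, 0.1, 0.1], [0.4, 0.3, 0.2, 0.1]])
print(top_k_categorical_accuracy(y_true, y_pred, k=2))  # 0.5
print(top_k_categorical_accuracy(y_true, y_pred, k=4))  # 1.0
```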

Train

# Initialize a W&B run
run = wandb.init(project="intro-keras", config=configs)

# Train your model
model.fit(
    trainloader,
    epochs=configs["epochs"],
    validation_data=validloader,
    callbacks=[
        WandbMetricsLogger(log_freq=10)
    ],  # Notice the use of WandbMetricsLogger here
)

# Close the W&B run
run.finish()

6.8 - Keras models

Use Weights & Biases for machine learning experiment tracking, dataset versioning, and project collaboration.

This Colab notebook introduces the WandbModelCheckpoint callback. Use this callback to log your model checkpoints to Weights & Biases Artifacts.

Setup and Installation

First, let us install the latest version of Weights & Biases. We will then authenticate this Colab instance to use W&B.

!pip install -qq -U wandb
import os
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
import tensorflow_datasets as tfds

# Weights and Biases related imports
import wandb
from wandb.integration.keras import WandbMetricsLogger
from wandb.integration.keras import WandbModelCheckpoint

If this is your first time using W&B or you are not logged in, the link that appears after running wandb.login() takes you to the sign-up/login page. Signing up for a free account takes only a few clicks.

wandb.login()

Hyperparameters

Using a proper config system is a recommended best practice for reproducible machine learning. W&B can track the hyperparameters of every experiment. In this Colab, we use a simple Python dict as our config system.

configs = dict(
    num_classes=10,
    shuffle_buffer=1024,
    batch_size=64,
    image_size=28,
    image_channels=1,
    earlystopping_patience=3,
    learning_rate=1e-3,
    epochs=10,
)

Dataset

In this Colab, we use the Fashion-MNIST dataset from the TensorFlow Datasets catalog. We aim to build a simple image classification pipeline using TensorFlow/Keras.

train_ds, valid_ds = tfds.load('fashion_mnist', split=['train', 'test'])
AUTOTUNE = tf.data.AUTOTUNE


def parse_data(example):
    # Get image
    image = example["image"]
    # image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # Get label
    label = example["label"]
    label = tf.one_hot(label, depth=configs["num_classes"])

    return image, label


def get_dataloader(ds, configs, dataloader_type="train"):
    dataloader = ds.map(parse_data, num_parallel_calls=AUTOTUNE)

    if dataloader_type == "train":
        dataloader = dataloader.shuffle(configs["shuffle_buffer"])
      
    dataloader = (
        dataloader
        .batch(configs["batch_size"])
        .prefetch(AUTOTUNE)
    )

    return dataloader
trainloader = get_dataloader(train_ds, configs)
validloader = get_dataloader(valid_ds, configs, dataloader_type="valid")

Model

def get_model(configs):
    backbone = tf.keras.applications.mobilenet_v2.MobileNetV2(weights='imagenet', include_top=False)
    backbone.trainable = False

    inputs = layers.Input(shape=(configs["image_size"], configs["image_size"], configs["image_channels"]))
    resize = layers.Resizing(32, 32)(inputs)
    neck = layers.Conv2D(3, (3,3), padding="same")(resize)
    preprocess_input = tf.keras.applications.mobilenet.preprocess_input(neck)
    x = backbone(preprocess_input)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(configs["num_classes"], activation="softmax")(x)

    return models.Model(inputs=inputs, outputs=outputs)
tf.keras.backend.clear_session()
model = get_model(configs)
model.summary()

Compile Model

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=[
        "accuracy",
        tf.keras.metrics.TopKCategoricalAccuracy(k=5, name="top@5_accuracy"),
    ],
)

Train

# Initialize a W&B run
run = wandb.init(project="intro-keras", config=configs)

# Train your model
model.fit(
    trainloader,
    epochs=configs["epochs"],
    validation_data=validloader,
    callbacks=[
        WandbMetricsLogger(log_freq=10),
        WandbModelCheckpoint(filepath="models/"),  # Notice the use of WandbModelCheckpoint here
    ],
)

# Close the W&B run
run.finish()

6.9 - Keras tables

Use Weights & Biases for machine learning experiment tracking, dataset versioning, and project collaboration.

This Colab notebook introduces the WandbEvalCallback, an abstract callback that you can inherit from to build callbacks for model prediction visualization and dataset visualization.

Setup and Installation

First, let us install the latest version of Weights & Biases. We will then authenticate this Colab instance to use W&B.

!pip install -qq -U wandb
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
import tensorflow_datasets as tfds

# Weights and Biases related imports
import wandb
from wandb.integration.keras import WandbMetricsLogger
from wandb.integration.keras import WandbModelCheckpoint
from wandb.integration.keras import WandbEvalCallback

If this is your first time using W&B or you are not logged in, the link that appears after running wandb.login() takes you to the sign-up/login page. Signing up for a free account takes only a few clicks.

wandb.login()

Hyperparameters

Using a proper config system is a recommended best practice for reproducible machine learning. W&B can track the hyperparameters of every experiment. In this Colab, we use a simple Python dict as our config system.

configs = dict(
    num_classes=10,
    shuffle_buffer=1024,
    batch_size=64,
    image_size=28,
    image_channels=1,
    earlystopping_patience=3,
    learning_rate=1e-3,
    epochs=10,
)

Dataset

In this Colab, we use the Fashion-MNIST dataset from the TensorFlow Datasets catalog. We aim to build a simple image classification pipeline using TensorFlow/Keras.

train_ds, valid_ds = tfds.load("fashion_mnist", split=["train", "test"])
AUTOTUNE = tf.data.AUTOTUNE


def parse_data(example):
    # Get image
    image = example["image"]
    # image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # Get label
    label = example["label"]
    label = tf.one_hot(label, depth=configs["num_classes"])

    return image, label


def get_dataloader(ds, configs, dataloader_type="train"):
    dataloader = ds.map(parse_data, num_parallel_calls=AUTOTUNE)

    if dataloader_type == "train":
        dataloader = dataloader.shuffle(configs["shuffle_buffer"])
      
    dataloader = (
        dataloader
        .batch(configs["batch_size"])
        .prefetch(AUTOTUNE)
    )

    return dataloader
trainloader = get_dataloader(train_ds, configs)
validloader = get_dataloader(valid_ds, configs, dataloader_type="valid")

Model

def get_model(configs):
    backbone = tf.keras.applications.mobilenet_v2.MobileNetV2(
        weights="imagenet", include_top=False
    )
    backbone.trainable = False

    inputs = layers.Input(
        shape=(configs["image_size"], configs["image_size"], configs["image_channels"])
    )
    resize = layers.Resizing(32, 32)(inputs)
    neck = layers.Conv2D(3, (3, 3), padding="same")(resize)
    preprocess_input = tf.keras.applications.mobilenet.preprocess_input(neck)
    x = backbone(preprocess_input)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(configs["num_classes"], activation="softmax")(x)

    return models.Model(inputs=inputs, outputs=outputs)
tf.keras.backend.clear_session()
model = get_model(configs)
model.summary()

Compile Model

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=[
        "accuracy",
        tf.keras.metrics.TopKCategoricalAccuracy(k=5, name="top@5_accuracy"),
    ],
)

WandbEvalCallback

The WandbEvalCallback is an abstract base class to build Keras callbacks for primarily model prediction visualization and secondarily dataset visualization.

This is a dataset- and task-agnostic abstract callback. To use it, inherit from this base callback class and implement the add_ground_truth and add_model_predictions methods.

The WandbEvalCallback is a utility class that provides helpful methods to:

  • create data and prediction wandb.Table instances,
  • log data and prediction Tables as wandb.Artifact,
  • log the data table in on_train_begin,
  • log the prediction table in on_epoch_end.

As an example, we have implemented WandbClfEvalCallback below for an image classification task. This example callback:

  • logs the validation data (data_table) to W&B,
  • performs inference and logs the prediction (pred_table) to W&B on every epoch end.
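WandbEvalCallback follows a template-method pattern: the base class decides when the tables are built and logged, and the subclass decides what goes into them. A simplified, hypothetical sketch of that control flow, with plain lists instead of wandb.Table and no logging, so it runs without W&B:

```python
from abc import ABC, abstractmethod


class EvalCallbackSketch(ABC):
    """Hypothetical, simplified stand-in for the WandbEvalCallback lifecycle (no W&B calls)."""

    def __init__(self, data_table_columns, pred_table_columns):
        self.data_table_columns = data_table_columns
        self.pred_table_columns = pred_table_columns
        self.data_rows = []
        self.pred_rows = []

    def on_train_begin(self, logs=None):
        # Build the ground-truth table once, before training starts.
        self.add_ground_truth(logs)

    def on_epoch_end(self, epoch, logs=None):
        # Start a fresh prediction table at the end of every epoch.
        self.pred_rows = []
        self.add_model_predictions(epoch, logs)

    @abstractmethod
    def add_ground_truth(self, logs=None):
        ...

    @abstractmethod
    def add_model_predictions(self, epoch, logs=None):
        ...


class ParitySketch(EvalCallbackSketch):
    """Toy subclass: the 'model' predicts whether a number is even."""

    def add_ground_truth(self, logs=None):
        for idx, value in enumerate([1, 2, 3]):
            self.data_rows.append([idx, value, value % 2 == 0])

    def add_model_predictions(self, epoch, logs=None):
        for idx, value, label in self.data_rows:
            self.pred_rows.append([epoch, idx, value, label, value % 2 == 0])


callback = ParitySketch(
    ["idx", "value", "is_even"], ["epoch", "idx", "value", "is_even", "prediction"]
)
callback.on_train_begin()
callback.on_epoch_end(epoch=0)
print(callback.pred_rows)
```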

How the memory footprint is reduced

We log the data_table to W&B when the on_train_begin method is invoked. Once it’s uploaded as a W&B Artifact, we get a reference to this table, which can be accessed through the data_table_ref attribute. Its data attribute is a 2D list that can be indexed like self.data_table_ref.data[idx][n], where idx is the row number and n is the column number. Because the prediction table refers back to these rows instead of uploading the images again every epoch, the callback keeps its memory footprint small. Let’s see the usage in the example below.

class WandbClfEvalCallback(WandbEvalCallback):
    def __init__(
        self, validloader, data_table_columns, pred_table_columns, num_samples=100
    ):
        super().__init__(data_table_columns, pred_table_columns)

        self.val_data = validloader.unbatch().take(num_samples)

    def add_ground_truth(self, logs=None):
        for idx, (image, label) in enumerate(self.val_data):
            self.data_table.add_data(idx, wandb.Image(image), np.argmax(label, axis=-1))

    def add_model_predictions(self, epoch, logs=None):
        # Get predictions
        preds = self._inference()
        table_idxs = self.data_table_ref.get_index()

        for idx in table_idxs:
            pred = preds[idx]
            self.pred_table.add_data(
                epoch,
                self.data_table_ref.data[idx][0],
                self.data_table_ref.data[idx][1],
                self.data_table_ref.data[idx][2],
                pred,
            )

    def _inference(self):
        preds = []
        for image, label in self.val_data:
            pred = self.model(tf.expand_dims(image, axis=0))
            argmax_pred = tf.argmax(pred, axis=-1).numpy()[0]
            preds.append(argmax_pred)

        return preds
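The saving comes from referencing rather than copying: the images are logged once with the data table, and every epoch's prediction rows only point back at those rows, which is what the data_table_ref lookups above do. A toy, plain-Python illustration of that bookkeeping; all names and values are hypothetical, and the real callback relies on W&B table references instead:

```python
# Hypothetical illustration (plain Python, no W&B): images live once in the data
# table, and each epoch's prediction rows hold only a row index into that table.
NUM_SAMPLES, NUM_EPOCHS = 100, 10

# Columns: idx, image (placeholder string), ground_truth
data_table = [[idx, f"image-{idx}", idx % 10] for idx in range(NUM_SAMPLES)]

prediction_tables = []
for epoch in range(NUM_EPOCHS):
    preds = [(idx * 7) % 10 for idx in range(NUM_SAMPLES)]  # stand-in model predictions
    # Store a reference (row index) instead of copying the image into every epoch's table.
    prediction_tables.append([[epoch, row_idx, preds[row_idx]] for row_idx in range(NUM_SAMPLES)])


def resolve(pred_row):
    """Expand a prediction row into (epoch, idx, image, ground_truth, prediction)."""
    epoch, row_idx, pred = pred_row
    idx, image, ground_truth = data_table[row_idx]
    return epoch, idx, image, ground_truth, pred


images_with_references = len(data_table)                    # uploaded once
images_without_references = NUM_SAMPLES * (1 + NUM_EPOCHS)  # data table plus one copy per epoch
print(images_with_references, images_without_references)    # 100 1100
print(resolve(prediction_tables[3][5]))                     # (3, 5, 'image-5', 5, 5)
```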

Train

# Initialize a W&B run
run = wandb.init(project="intro-keras", config=configs)

# Train your model
model.fit(
    trainloader,
    epochs=configs["epochs"],
    validation_data=validloader,
    callbacks=[
        WandbMetricsLogger(log_freq=10),
        WandbClfEvalCallback(
            validloader,
            data_table_columns=["idx", "image", "ground_truth"],
            pred_table_columns=["epoch", "idx", "image", "ground_truth", "prediction"],
        ),  # Notice the use of WandbEvalCallback here
    ],
)

# Close the W&B run
run.finish()

6.10 - XGBoost Sweeps

Use Weights & Biases for machine learning experiment tracking, dataset versioning, and project collaboration.

Squeezing the best performance out of tree-based models requires selecting the right hyperparameters. How many early_stopping_rounds? What should the max_depth of a tree be?

Searching through high-dimensional hyperparameter spaces to find the most performant model can get unwieldy very fast. Hyperparameter sweeps provide an organized and efficient way to conduct a battle royale of models and crown a winner. They enable this by automatically searching through combinations of hyperparameter values to find the best-performing ones.

In this tutorial we’ll see how you can run sophisticated hyperparameter sweeps on XGBoost models in 3 easy steps using Weights and Biases.

For a teaser, check out the plots below:

sweeps_xgboost

Sweeps: An Overview

Running a hyperparameter sweep with Weights & Biases is very easy. There are just 3 simple steps:

  1. Define the sweep: we do this by creating a dictionary-like object that specifies the sweep: which parameters to search through, which search strategy to use, which metric to optimize.

  2. Initialize the sweep: with one line of code we initialize the sweep and pass in the dictionary of sweep configurations: sweep_id = wandb.sweep(sweep_config)

  3. Run the sweep agent: also accomplished with one line of code, we call wandb.agent() and pass the sweep_id along with a function that defines your model architecture and trains it: wandb.agent(sweep_id, function=train)

That’s all there is to running a hyperparameter sweep.

In the notebook below, we’ll walk through these 3 steps in more detail.

We highly encourage you to fork this notebook, tweak the parameters, or try the model with your own dataset.

Resources

!pip install wandb -qU

import wandb
wandb.login()

1. Define the Sweep

Weights & Biases sweeps give you powerful levers to configure your sweeps exactly how you want them, with just a few lines of code. The sweeps config can be defined as a dictionary or a YAML file.

Let’s walk through some of them together:

  • Metric: This is the metric the sweeps are attempting to optimize. A metric takes a name (this metric should be logged by your training script) and a goal (maximize or minimize).
  • Search Strategy: Specified using the "method" key. Sweeps support several different search strategies:
      • Grid Search: Iterates over every combination of hyperparameter values.
      • Random Search: Iterates over randomly chosen combinations of hyperparameter values.
      • Bayesian Search: Creates a probabilistic model that maps hyperparameters to the probability of a metric score, and chooses parameters with a high probability of improving the metric. The objective of Bayesian optimization is to spend more time picking hyperparameter values, but in doing so to try out fewer of them.
  • Parameters: A dictionary containing the hyperparameter names and, for each one, discrete values, a range, or a distribution from which to pull its value on each iteration.

For details, see the list of all sweep configuration options.

sweep_config = {
    "method": "random",  # try "grid" or "bayes"
    "metric": {
        "name": "accuracy",
        "goal": "maximize"
    },
    "parameters": {
        "booster": {
            "values": ["gbtree","gblinear"]
        },
        "max_depth": {
            "values": [3, 6, 9, 12]
        },
        "learning_rate": {
            "values": [0.1, 0.05, 0.2]
        },
        "subsample": {
            "values": [1, 0.5, 0.3]
        }
    }
}
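This search space has 2 × 4 × 3 × 3 = 72 combinations, so a grid search would need 72 runs, while a random search samples from the space, which is why it is paired with a run count later on. A plain-Python sketch of both strategies over the same parameters, assuming random search draws each value uniformly from its values list; this is an illustration, not the W&B sweep controller, and the helper names are hypothetical:

```python
import itertools
import random

# Same search space as sweep_config["parameters"] above.
parameters = {
    "booster": {"values": ["gbtree", "gblinear"]},
    "max_depth": {"values": [3, 6, 9, 12]},
    "learning_rate": {"values": [0.1, 0.05, 0.2]},
    "subsample": {"values": [1, 0.5, 0.3]},
}


def grid_configs(params):
    """Yield every combination of values, as a grid search would enumerate them."""
    names = list(params)
    for combo in itertools.product(*(params[name]["values"] for name in names)):
        yield dict(zip(names, combo))


def random_config(params, rng):
    """Draw one configuration, picking each value uniformly from its list."""
    return {name: rng.choice(spec["values"]) for name, spec in params.items()}


print(sum(1 for _ in grid_configs(parameters)))  # 72 = 2 * 4 * 3 * 3
rng = random.Random(0)
print(random_config(parameters, rng))
```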

2. Initialize the Sweep

Calling wandb.sweep starts a Sweep Controller, a centralized process that provides parameter settings to any agent that queries it and expects the agents to report metric performance back through wandb logging.

sweep_id = wandb.sweep(sweep_config, project="XGBoost-sweeps")

Define your training process

Before we can run the sweep, we need to define a function that creates and trains the model – the function that takes in hyperparameter values and spits out metrics.

We’ll also need wandb to be integrated into our script. There are three main components:

  • wandb.init(): Initialize a new W&B run. Each run is a single execution of the training script.
  • wandb.config: Save all your hyperparameters in a config object. This lets you use our app to sort and compare your runs by hyperparameter values.
  • wandb.log(): Logs metrics and custom objects, such as images, videos, audio files, HTML, plots, or point clouds.

We also need to download the data:

!wget https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv
# XGBoost model for Pima Indians dataset
from numpy import loadtxt
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Train and evaluate one model; wandb.agent calls this function once per sweep run
def train():
  config_defaults = {
    "booster": "gbtree",
    "max_depth": 3,
    "learning_rate": 0.1,
    "subsample": 1,
    "seed": 117,
    "test_size": 0.33,
  }

  wandb.init(config=config_defaults)  # defaults are over-ridden during the sweep
  config = wandb.config

  # load data and split into predictors and targets
  dataset = loadtxt("pima-indians-diabetes.data.csv", delimiter=",")
  X, Y = dataset[:, :8], dataset[:, 8]

  # split data into train and test sets
  X_train, X_test, y_train, y_test = train_test_split(X, Y,
                                                      test_size=config.test_size,
                                                      random_state=config.seed)

  # fit model on train
  model = XGBClassifier(booster=config.booster, max_depth=config.max_depth,
                        learning_rate=config.learning_rate, subsample=config.subsample)
  model.fit(X_train, y_train)

  # make predictions on test
  y_pred = model.predict(X_test)
  predictions = [round(value) for value in y_pred]

  # evaluate predictions
  accuracy = accuracy_score(y_test, predictions)
  print(f"Accuracy: {accuracy:.0%}")
  wandb.log({"accuracy": accuracy})
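The comment in train() notes that the defaults are over-ridden during the sweep. A minimal sketch of that precedence rule in plain Python, assuming the values the sweep supplies simply take priority over config_defaults and everything else falls back to the defaults. resolve_config is a hypothetical helper that mirrors the behavior the comment describes; it is not wandb's internal code, and the sweep values shown are illustrative.

```python
def resolve_config(defaults: dict, sweep_params: dict) -> dict:
    """Hypothetical helper: sweep-supplied values win, missing keys fall back to defaults."""
    # later dicts override earlier ones in a {**a, **b} merge; neither input is mutated
    return {**defaults, **sweep_params}


config_defaults = {
    "booster": "gbtree",
    "max_depth": 3,
    "learning_rate": 0.1,
    "subsample": 1,
    "seed": 117,
    "test_size": 0.33,
}

# illustrative parameters a controller might hand to one run
run_config = resolve_config(config_defaults, {"booster": "gblinear", "learning_rate": 0.3})
print(run_config)
```

This is why it is safe to keep seed and test_size in the defaults even when the sweep configuration does not mention them: each run still receives a complete config.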

3. Run the Sweep with an agent

Now, we call wandb.agent to start up our sweep.

You can call wandb.agent on any machine where you’re logged into W&B that has

  • the sweep_id,
  • the dataset and train function

and that machine will join the sweep.

Note: a random sweep will by default run forever, trying new parameter combinations until the cows come home, or until you turn the sweep off from the app UI. You can prevent this by providing the total count of runs you’d like the agent to complete.

wandb.agent(sweep_id, train, count=25)
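Why count matters for random search can be shown with a toy generator: random sampling never runs out of new combinations, so something outside the search has to stop it. In this sketch (random_configs and run_agent are hypothetical helpers, not the wandb API), itertools.islice plays the role that count plays for wandb.agent.

```python
import itertools
import random


def random_configs(search_space, seed=0):
    """Endless stream of random parameter combinations (a toy random search)."""
    rng = random.Random(seed)
    while True:
        yield {name: rng.choice(values) for name, values in search_space.items()}


def run_agent(train_fn, search_space, count):
    # without islice this would loop forever, like an unbounded random sweep
    configs = itertools.islice(random_configs(search_space), count)
    return [train_fn(cfg) for cfg in configs]


space = {"booster": ["gbtree", "gblinear"], "max_depth": [3, 6, 9]}
results = run_agent(lambda cfg: cfg["max_depth"], space, count=25)
print(len(results))  # 25
```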

Visualize your results

Now that your sweep is finished, it’s time to look at the results.

Weights & Biases will generate a number of useful plots for you automatically.

Parallel coordinates plot

This plot maps hyperparameter values to model metrics. It’s useful for homing in on combinations of hyperparameters that led to the best model performance.

This plot seems to indicate that using a tree as our learner slightly, but not mind-blowingly, outperforms using a simple linear model as our learner.


Hyperparameter importance plot

The hyperparameter importance plot shows which hyperparameter values had the biggest impact on your metrics.

We report both the correlation (treating it as a linear predictor) and the feature importance (after training a random forest on your results) so you can see which parameters had the biggest effect and whether that effect was positive or negative.

Reading this chart, we see quantitative confirmation of the trend we noticed in the parallel coordinates chart above: the largest impact on validation accuracy came from the choice of learner, and the gblinear learners were generally worse than gbtree learners.


These visualizations can help you save both time and resources running expensive hyperparameter optimizations by homing in on the parameters (and value ranges) that are the most important, and thereby worthy of further exploration.
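The correlation column of the hyperparameter importance plot treats each parameter as a linear predictor of the metric. The exact computation W&B uses is not described here; the sketch below is just the standard Pearson coefficient, which shows what a positive or negative sign means. The input values are dummies, not results from this sweep. A categorical parameter such as booster would first need a numeric encoding (for example 1 for gbtree, 0 for gblinear), which is an assumption of this sketch.

```python
import math


def pearson(xs, ys):
    """Pearson correlation between one hyperparameter's values and a metric."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


# dummy values, only to show the sign convention
print(pearson([0.1, 0.2, 0.3], [0.70, 0.72, 0.74]))  # close to +1: metric rises with the parameter
print(pearson([3, 6, 9], [0.74, 0.72, 0.70]))        # close to -1: metric falls as the parameter grows
```

Correlation only captures linear, one-parameter-at-a-time effects, which is why the plot also reports a random-forest feature importance alongside it.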

7 - Weave and Models integration demo

This notebook shows how to use W&B Weave together with W&B Models. Specifically, this example considers two different teams.

  • The Model Team: the model building team fine-tunes a new Chat Model (Llama 3.2) and saves it to the registry using W&B Models.
  • The App Team: the app development team retrieves the Chat Model to create and evaluate a new RAG chatbot using W&B Weave.

Find the public workspace for both W&B Models and W&B Weave here.


The workflow covers the following steps:

  1. Instrument the RAG app code with W&B Weave
  2. Fine-tune an LLM (such as Llama 3.2, but you can replace it with any other LLM) and track it with W&B Models
  3. Log the fine-tuned model to the W&B Registry
  4. Implement the RAG app with the new fine-tuned model and evaluate the app with W&B Weave
  5. Once satisfied with the results, save a reference to the updated RAG app in the W&B Registry

Note:

The RagModel referenced below is a top-level weave.Model that you can consider a complete RAG app. It contains a ChatModel, a vector database, and a prompt. The ChatModel is itself a weave.Model that contains the code to download an artifact from the W&B Registry, and it can be swapped out to support any other chat model as part of the RagModel. For more details, see the complete model on Weave.
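The composition described in the note (a RAG app that owns a chat model, a document store, and a prompt) is what makes the later "swap only the chat model" step possible. A minimal sketch with plain dataclasses; ToyChatModel and ToyRagModel are hypothetical stand-ins, not weave.Model subclasses, and naive word-overlap retrieval stands in for the vector database.

```python
from dataclasses import dataclass
from typing import Callable, List


def _words(text: str) -> set:
    # lowercase words with trailing punctuation removed
    return {w.strip(".,?!") for w in text.lower().split()}


@dataclass
class ToyChatModel:
    """Hypothetical stand-in for a ChatModel: wraps any text-generation callable."""
    name: str
    generate: Callable[[str], str]

    def predict(self, prompt: str) -> str:
        return self.generate(prompt)


@dataclass
class ToyRagModel:
    """Hypothetical stand-in for RagModel: chat model + documents + prompt template."""
    chat_model: ToyChatModel
    documents: List[str]
    prompt_template: str

    def retrieve(self, question: str) -> str:
        # pick the document sharing the most words with the question
        q = _words(question)
        return max(self.documents, key=lambda d: len(q & _words(d)))

    def predict(self, question: str) -> str:
        context = self.retrieve(question)
        prompt = self.prompt_template.format(context=context, question=question)
        return self.chat_model.predict(prompt)


rag = ToyRagModel(
    chat_model=ToyChatModel("base", lambda p: "base: " + p),
    documents=["The sky is blue.", "Grass is green."],
    prompt_template="Context: {context}\nQuestion: {question}",
)
# swap only the chat model; documents and prompt stay untouched
rag.chat_model = ToyChatModel("finetuned", lambda p: "finetuned: " + p)
answer = rag.predict("What color is grass?")
print(answer)
```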

1. Setup

First, install weave and wandb, then log in with an API key. You can create and view your API keys at https://wandb.ai/settings.

pip install weave wandb
import wandb
import weave
import pandas as pd

PROJECT = "weave-cookboook-demo"
ENTITY = "wandb-smle"

wandb.login()
weave.init(ENTITY + "/" + PROJECT)

2. Make ChatModel based on Artifact

Retrieve the fine-tuned chat model from the Registry and create a weave.Model from it to plug directly into the RagModel in the next step. It takes the same parameters as the existing ChatModel; only the init (model_post_init) and predict methods change.

pip install unsloth
pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

The model team fine-tuned different Llama-3.2 models using the unsloth library to speed up fine-tuning. Therefore, use the special unsloth.FastLanguageModel or peft.AutoPeftModelForCausalLM models with adapters to load the model once it is downloaded from the Registry. Copy the loading code from the “Use” tab in the Registry and paste it into model_post_init.

import weave
from pydantic import PrivateAttr
from typing import Any, List, Dict, Optional
from unsloth import FastLanguageModel
import torch


class UnslothLoRAChatModel(weave.Model):
    """
    Define an extra ChatModel class to store and version more parameters than just the model name.
    This enables fine-tuning on specific parameters.
    """

    chat_model: str
    cm_temperature: float
    cm_max_new_tokens: int
    cm_quantize: bool
    inference_batch_size: int
    dtype: Any
    device: str
    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()

    def model_post_init(self, __context):
        # paste this from the "Use" tab from the registry
        run = wandb.init(project=PROJECT, job_type="model_download")
        artifact = run.use_artifact(f"{self.chat_model}")
        model_path = artifact.download()

        # unsloth version (enable native 2x faster inference)
        self._model, self._tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_path,
            max_seq_length=self.cm_max_new_tokens,
            dtype=self.dtype,
            load_in_4bit=self.cm_quantize,
        )
        FastLanguageModel.for_inference(self._model)

    @weave.op()
    async def predict(self, query: List[Dict[str, str]]) -> str:
        # add_generation_prompt = true - Must add for generation
        input_ids = self._tokenizer.apply_chat_template(
            query,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to("cuda")

        output_ids = self._model.generate(
            input_ids=input_ids,
            max_new_tokens=64,
            use_cache=True,
            temperature=1.5,
            min_p=0.1,
        )

        # drop the prompt tokens and decode only the newly generated ones
        decoded_output = self._tokenizer.decode(
            output_ids[0][input_ids.shape[1] :], skip_special_tokens=True
        )

        return decoded_output.strip()
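The slice output_ids[0][input_ids.shape[1] :] in predict exists because, for decoder-only models such as Llama, generate() returns the prompt tokens followed by the newly generated tokens. The same idea with plain Python lists and made-up token IDs (strip_prompt is a hypothetical helper):

```python
def strip_prompt(output_ids, prompt_length):
    """Keep only the token ids generated after the prompt."""
    return output_ids[prompt_length:]


prompt_ids = [101, 7, 8, 9]        # dummy ids for the chat-formatted prompt
generated = prompt_ids + [42, 43]  # dummy: prompt + continuation, as generate() returns it
print(strip_prompt(generated, len(prompt_ids)))  # [42, 43]
```

Without the slice, the decoded answer would repeat the whole prompt (including the chat template) before the model's reply.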

Now create a new model with a specific link from the registry:

MODEL_REG_URL = "wandb32/wandb-registry-RAG Chat Models/Finetuned Llama-3.2:v3"

max_seq_length = 2048
dtype = None
load_in_4bit = True

new_chat_model = UnslothLoRAChatModel(
    name="UnslothLoRAChatModelRag",
    chat_model=MODEL_REG_URL,
    cm_temperature=1.0,
    cm_max_new_tokens=max_seq_length,
    cm_quantize=load_in_4bit,
    inference_batch_size=max_seq_length,
    dtype=dtype,
    device="auto",
)
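MODEL_REG_URL pins a specific version through its :v3 alias. If you want to check or log which registry, collection, and alias a run loaded, a hypothetical helper can split the string, assuming it always has the <entity>/<registry project>/<collection>:<alias> shape used here. Registry and collection names may contain spaces, so the helper splits only on / and on the last :.

```python
from typing import NamedTuple


class RegistryPath(NamedTuple):
    entity: str
    project: str
    collection: str
    alias: str


def parse_registry_path(path: str) -> RegistryPath:
    """Hypothetical helper: split '<entity>/<project>/<collection>:<alias>'."""
    name, alias = path.rsplit(":", 1)            # alias follows the last ':'
    entity, project, collection = name.split("/", 2)
    return RegistryPath(entity, project, collection, alias)


p = parse_registry_path("wandb32/wandb-registry-RAG Chat Models/Finetuned Llama-3.2:v3")
print(p.project, "|", p.collection, "|", p.alias)
```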

Finally, test the new model by running a prediction asynchronously:

await new_chat_model.predict(
    [{"role": "user", "content": "What is the capital of Germany?"}]
)
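Top-level await works in Jupyter and Colab because the notebook already runs an event loop. In a plain Python script, the same call has to be wrapped in a coroutine and started with asyncio.run. A minimal sketch, using a hypothetical EchoModel with an async predict in place of the real model so it runs without a GPU:

```python
import asyncio


class EchoModel:
    """Hypothetical stand-in with an async predict, like UnslothLoRAChatModel.predict."""

    async def predict(self, query):
        # return the last user message, upper-cased, instead of calling an LLM
        return query[-1]["content"].upper()


async def main():
    model = EchoModel()
    return await model.predict([{"role": "user", "content": "What is the capital of Germany?"}])


answer = asyncio.run(main())  # asyncio.run fails inside an already-running loop (e.g. a notebook)
print(answer)
```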

3. Integrate new ChatModel version into RagModel

Building a RAG app from a fine-tuned chat model can provide several advantages, particularly in enhancing the performance and versatility of conversational AI systems.

Now retrieve the RagModel from the existing Weave project (you can fetch the weave ref for the current RagModel from the Use tab, as shown in the image below) and swap in the new ChatModel. There is no need to change or re-create any of the other components (VDB, prompts, etc.)!

pip install litellm faiss-gpu
RagModel = weave.ref(
    "weave:///wandb-smle/weave-cookboook-demo/object/RagModel:cqRaGKcxutBWXyM0fCGTR1Yk2mISLsNari4wlGTwERo"
).get()
# MAGIC: exchange chat_model and publish new version (no need to worry about other RAG components)
RagModel.chat_model = new_chat_model
# First publish the new version so that it is referenced during predictions
PUB_REFERENCE = weave.publish(RagModel, "RagModel")
await RagModel.predict("When was the first conference on climate change?")
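The weave refs in this notebook follow the pattern weave:///<entity>/<project>/object/<name>:<digest> (the same pattern step 5 builds for models_link). A hypothetical helper to split one apart, for example to record which RagModel version you fetched before swapping its chat model. It handles only object refs shaped like the ones shown here; weave.ref itself is what actually resolves the URI.

```python
from typing import NamedTuple


class WeaveObjectRef(NamedTuple):
    entity: str
    project: str
    name: str
    digest: str


def parse_weave_object_ref(uri: str) -> WeaveObjectRef:
    """Hypothetical helper: split 'weave:///<entity>/<project>/object/<name>:<digest>'."""
    prefix = "weave:///"
    if not uri.startswith(prefix):
        raise ValueError(f"not a weave URI: {uri}")
    entity, project, kind, name_digest = uri[len(prefix):].split("/", 3)
    if kind != "object":
        raise ValueError(f"expected an object ref, got {kind!r}")
    name, digest = name_digest.rsplit(":", 1)
    return WeaveObjectRef(entity, project, name, digest)


ref = parse_weave_object_ref(
    "weave:///wandb-smle/weave-cookboook-demo/object/RagModel:cqRaGKcxutBWXyM0fCGTR1Yk2mISLsNari4wlGTwERo"
)
print(ref.name, ref.digest[:8])
```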

4. Run new weave.Evaluation connecting to the existing models run

Finally, evaluate the new RagModel on the existing weave.Evaluation. To make the integration as easy as possible, include the following changes.

From a Models perspective:

  • Getting the model from the registry creates a new wandb.run, which is part of the E2E lineage of the chat model
  • Add the Trace ID (with current eval ID) to the run config so that the model team can click the link to go to the corresponding Weave page

From a Weave perspective:

  • Save the artifact/registry link as input to the ChatModel (which is part of the RagModel)
  • Save the run.id as an extra column in the traces with weave.attributes
# MAGIC: get an evaluation with an eval dataset and scorers and use them
WEAVE_EVAL = "weave:///wandb-smle/weave-cookboook-demo/object/climate_rag_eval:ntRX6qn3Tx6w3UEVZXdhIh1BWGh7uXcQpOQnIuvnSgo"
climate_rag_eval = weave.ref(WEAVE_EVAL).get()

with weave.attributes({"wandb-run-id": wandb.run.id}):
    # use .call attribute to retrieve both the result and the call in order to save eval trace to Models
    summary, call = await climate_rag_eval.evaluate.call(climate_rag_eval, RagModel)
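The Models bullets above also call for logging the results and adding the evaluation's trace ID to the run config, which this cell does not show. A minimal sketch, assuming summary is a nested dict of scorer results and that the Weave Call object exposes its ID as call.id; the flatten helper, the placeholder summary, and the config key name weave_eval_call_id are all hypothetical.

```python
def flatten(summary: dict, sep: str = "/", prefix: str = "") -> dict:
    """Flatten nested evaluation results into 'scorer/metric' keys for wandb.log."""
    flat = {}
    for key, value in summary.items():
        full_key = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, sep, full_key))  # recurse into nested scorer output
        else:
            flat[full_key] = value
    return flat


# placeholder summary shaped like nested scorer output (values are not real results)
example = {"scorer_a": {"mean": 0.5}, "scorer_b": {"score": {"mean": 0.25}}}
print(flatten(example))

# Usage in the notebook (needs the live run and the summary/call objects from above):
# wandb.run.log(flatten(summary))
# wandb.run.config.update({"weave_eval_call_id": call.id})
```

pandas, which the setup cell imports as pd, offers pd.json_normalize(summary, sep="/") as an alternative way to do the same flattening.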

5. Save the new RAG model on the Registry

To share the new RAG model effectively, push it to the Registry as a reference artifact, adding the Weave object version as an alias.

MODELS_OBJECT_VERSION = PUB_REFERENCE.digest  # weave object version
MODELS_OBJECT_NAME = PUB_REFERENCE.name  # weave object name

models_url = f"https://wandb.ai/{ENTITY}/{PROJECT}/weave/objects/{MODELS_OBJECT_NAME}/versions/{MODELS_OBJECT_VERSION}"
models_link = (
    f"weave:///{ENTITY}/{PROJECT}/object/{MODELS_OBJECT_NAME}:{MODELS_OBJECT_VERSION}"
)

with wandb.init(project=PROJECT, entity=ENTITY) as run:
    # create new Artifact
    artifact_model = wandb.Artifact(
        name="RagModel",
        type="model",
        description="Models Link from RagModel in Weave",
        metadata={"url": models_url},
    )
    artifact_model.add_reference(models_link, name="model", checksum=False)

    # log new artifact
    run.log_artifact(artifact_model, aliases=[MODELS_OBJECT_VERSION])

    # link to registry
    run.link_artifact(
        artifact_model, target_path="wandb32/wandb-registry-RAG Models/RAG Model"
    )