Keras

Tools for integrating wandb with Keras.

Classes

class WandbCallback: WandbCallback automatically integrates keras with wandb.

class WandbEvalCallback: Abstract base class to build Keras callbacks for model prediction visualization.

class WandbMetricsLogger: Logger that sends system metrics to W&B.

class WandbModelCheckpoint: A checkpoint that periodically saves a Keras model or model weights.

2 - WandbCallback

WandbCallback automatically integrates keras with wandb.

WandbCallback(
    monitor="val_loss", verbose=0, mode="auto", save_weights_only=False,
    log_weights=False, log_gradients=False, save_model=True,
    training_data=None, validation_data=None, labels=None, predictions=36,
    generator=None, input_type=None, output_type=None, log_evaluation=False,
    validation_steps=None, class_colors=None, log_batch_frequency=None,
    log_best_prefix="best_", save_graph=True, validation_indexes=None,
    validation_row_processor=None, prediction_row_processor=None,
    infer_missing_processors=True, log_evaluation_frequency=0,
    compute_flops=False, **kwargs
)

Example:

model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    callbacks=[WandbCallback()],
)

WandbCallback will automatically log history data from any metrics collected by keras: loss and anything passed into keras_model.compile().

WandbCallback will set summary metrics for the run associated with the “best” training step, where “best” is defined by the monitor and mode attributes. This defaults to the epoch with the minimum val_loss. WandbCallback will by default save the model associated with the best epoch.
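
The way monitor and mode pick the "best" epoch can be sketched in plain Python. This is only an illustration of the min/max idea, not W&B's internal code: the best_epoch helper and the history list are hypothetical, and the auto mode is left out because its guessing rule is not described here.

```python
def best_epoch(history, monitor="val_loss", mode="min"):
    """Return (epoch_index, value) of the best epoch for `monitor`.

    `history` is a list of per-epoch metric dicts. Hypothetical helper,
    written only to illustrate how monitor and mode define "best".
    """
    if mode not in ("min", "max"):
        raise ValueError("this sketch only handles mode='min' or mode='max'")
    pick = min if mode == "min" else max
    # min/max with a key return the first epoch on ties.
    epoch = pick(range(len(history)), key=lambda e: history[e][monitor])
    return epoch, history[epoch][monitor]


history = [
    {"val_loss": 0.9, "val_accuracy": 0.60},
    {"val_loss": 0.4, "val_accuracy": 0.80},
    {"val_loss": 0.5, "val_accuracy": 0.85},
]

print(best_epoch(history))                         # (1, 0.4)
print(best_epoch(history, "val_accuracy", "max"))  # (2, 0.85)
```

With the defaults (monitor="val_loss", mode resolving to min), the second epoch would be the one whose metrics end up in the run summary and whose model is saved.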

WandbCallback can optionally log gradient and parameter histograms.

WandbCallback can optionally save training and validation data for wandb to visualize.

Args
monitor (str) name of metric to monitor. Defaults to val_loss.
mode (str) one of {auto, min, max}. min: save the model when monitor is minimized. max: save the model when monitor is maximized. auto: try to guess when to save the model (default).
save_model (boolean) if True, save a model whenever monitor beats all previous epochs; if False, don’t save models.
save_graph (boolean) if True, save the model graph to wandb. Defaults to True.
save_weights_only (boolean) if True, then only the model’s weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).
log_weights (boolean) if True, save histograms of the weights of the model’s layers.
log_gradients (boolean) if True, log histograms of the training gradients.
training_data (tuple) Same format (X,y) as passed to model.fit. Needed for calculating gradients; mandatory if log_gradients is True.
validation_data (tuple) Same format (X,y) as passed to model.fit. A set of data for wandb to visualize. If this is set, wandb makes a small number of predictions every epoch and saves the results for later visualization. If you are working with image data, also set input_type and output_type so that the data is logged correctly.
generator (generator) a generator that returns validation data for wandb to visualize. This generator should return tuples (X,y). Either validation_data or generator should be set for wandb to visualize specific data examples. If you are working with image data, also set input_type and output_type so that the data is logged correctly.
validation_steps (int) if validation_data is a generator, how many steps to run the generator for the full validation set.
labels (list) If you are visualizing your data with wandb, this list of labels converts numeric output to understandable strings when you are building a multiclass classifier. If you are building a binary classifier, you can pass in a list of two labels [label for false, label for true]. If neither validation_data nor generator is set, this has no effect.
predictions (int) the number of predictions to make for visualization each epoch, max is 100.
input_type (string) type of the model input to help visualization. can be one of: (image, images, segmentation_mask, auto).
output_type (string) type of the model output to help visualization. can be one of: (image, images, segmentation_mask, label).
log_evaluation (boolean) if True, save a Table containing validation data and the model’s predictions at each epoch. See validation_indexes, validation_row_processor, and prediction_row_processor for additional details.
class_colors ([float, float, float]) if the input or output is a segmentation mask, an array containing an rgb tuple (range 0-1) for each class.
log_batch_frequency (integer) if None, callback will log every epoch. If set to integer, callback will log training metrics every log_batch_frequency batches.
log_best_prefix (string) if None, no extra summary metrics will be saved. If set to a string, the monitored metric and epoch will be prepended with this value and stored as summary metrics.
validation_indexes ([wandb.data_types._TableLinkMixin]) an ordered list of index keys to associate with each validation example. If log_evaluation is True and validation_indexes is provided, then a Table of validation data will not be created and instead each prediction will be associated with the row represented by the TableLinkMixin. The most common way to obtain such keys is to use Table.get_index(), which returns a list of row keys.
validation_row_processor (Callable) a function to apply to the validation data, commonly used to visualize the data. The function will receive an ndx (int) and a row (dict). If your model has a single input, then row["input"] will be the input data for the row. Else, it will be keyed based on the name of the input slot. If your fit function takes a single target, then row["target"] will be the target data for the row. Else, it will be keyed based on the name of the output slots. For example, if your input data is a single ndarray, but you wish to visualize the data as an Image, then you can provide lambda ndx, row: {"img": wandb.Image(row["input"])} as the processor. Ignored if log_evaluation is False or validation_indexes are present.
prediction_row_processor (Callable) same as validation_row_processor, but applied to the model’s output. row["output"] will contain the results of the model output.
infer_missing_processors (bool) Determines if validation_row_processor and prediction_row_processor should be inferred if missing. Defaults to True. If labels are provided, we will attempt to infer classification-type processors where appropriate.
log_evaluation_frequency (int) Determines how often evaluation results are logged. Default 0 (only at the end of training). Set to 1 to log every epoch, 2 to log every other epoch, and so on. Has no effect when log_evaluation is False.
compute_flops (bool) Compute the FLOPs of your Keras Sequential or Functional model, reported in GigaFLOPs.
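
The row-processor contract described above (a callable that receives an index and a row dict and returns a dict of values to show) can be tried out without W&B. The sketch below is a minimal illustration under assumptions: summarize_row and fake_row are hypothetical names, the row uses a plain list instead of an ndarray, and returning plain numbers rather than wandb media types is assumed to be acceptable.

```python
def summarize_row(ndx, row):
    """Hypothetical validation_row_processor.

    Receives the row index and a dict holding the raw input under "input"
    (single-input model) and the target under "target" (single target).
    Returns a dict of derived values for that row.
    """
    values = row["input"]
    return {
        "ndx": ndx,
        "num_values": len(values),
        "mean_value": sum(values) / len(values),
        "target": row["target"],
    }


# Stand-in for the row dict that the callback would pass in.
fake_row = {"input": [0.0, 0.5, 1.0], "target": 1}
processed = summarize_row(0, fake_row)
print(processed)
```

For image inputs, the lambda shown in the validation_row_processor description, which wraps row["input"] in wandb.Image, follows the same shape.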

Methods

get_flops

get_flops() -> float

Calculate FLOPS [GFLOPs] for a tf.keras.Model or tf.keras.Sequential model in inference mode.

It uses tf.compat.v1.profiler under the hood.

set_model

set_model(
    model
)

set_params

set_params(
    params
)

3 - WandbEvalCallback

Abstract base class to build Keras callbacks for model prediction visualization.

WandbEvalCallback(
    data_table_columns: List[str],
    pred_table_columns: List[str],
    *args,
    **kwargs
) -> None

Use it to build callbacks that visualize model predictions at the end of each epoch (on_epoch_end) and that can be passed to model.fit() for tasks such as classification, object detection, and segmentation.

To use this, inherit from this base callback class and implement the add_ground_truth and add_model_predictions methods.

The base class will take care of the following:

  • Initialize data_table for logging the ground truth and pred_table for predictions.
  • The data uploaded to data_table is used as a reference for the pred_table. This is to reduce the memory footprint. The data_table_ref is a list that can be used to access the referenced data. Check out the example below to see how it’s done.
  • Log the tables to W&B as W&B Artifacts.
  • Each new pred_table is logged as a new version with aliases.

Example:

import tensorflow as tf
import wandb
from wandb.integration.keras import WandbEvalCallback


class WandbClfEvalCallback(WandbEvalCallback):
    def __init__(self, validation_data, data_table_columns, pred_table_columns):
        super().__init__(data_table_columns, pred_table_columns)

        self.x = validation_data[0]
        self.y = validation_data[1]

    # Called once at the start of training; the signature matches the abstract method.
    def add_ground_truth(self, logs=None):
        for idx, (image, label) in enumerate(zip(self.x, self.y)):
            self.data_table.add_data(idx, wandb.Image(image), label)

    # Called at the end of each epoch; rows reference data_table instead of re-uploading.
    def add_model_predictions(self, epoch, logs=None):
        preds = self.model.predict(self.x, verbose=0)
        preds = tf.argmax(preds, axis=-1)

        data_table_ref = self.data_table_ref
        table_idxs = data_table_ref.get_index()

        for idx in table_idxs:
            pred = preds[idx]
            self.pred_table.add_data(
                epoch,
                data_table_ref.data[idx][0],
                data_table_ref.data[idx][1],
                data_table_ref.data[idx][2],
                pred,
            )


model.fit(
    x,
    y,
    epochs=2,
    validation_data=(x, y),
    callbacks=[
        WandbClfEvalCallback(
            validation_data=(x, y),
            data_table_columns=["idx", "image", "label"],
            pred_table_columns=["epoch", "idx", "image", "label", "pred"],
        )
    ],
)

For more fine-grained control, you can override the on_train_begin and on_epoch_end methods. If you want to log the samples after N batches, you can implement the on_train_batch_end method.

Methods

add_ground_truth

@abc.abstractmethod
add_ground_truth(
    logs: Optional[Dict[str, float]] = None
) -> None

Add ground truth data to data_table.

Use this method to write the logic for adding validation/training data to data_table, which is initialized using the init_data_table method.

Example:

for idx, data in enumerate(dataloader):
    self.data_table.add_data(idx, data)

This method is called once, in on_train_begin or an equivalent hook.

add_model_predictions

@abc.abstractmethod
add_model_predictions(
    epoch: int,
    logs: Optional[Dict[str, float]] = None
) -> None

Add a prediction from a model to pred_table.

Use this method to write the logic for adding model predictions for validation/training data to pred_table, which is initialized using the init_pred_table method.

Example:

# Assuming the dataloader is not shuffling the samples.
for idx, data in enumerate(dataloader):
    preds = model.predict(data)
    self.pred_table.add_data(
        self.data_table_ref.data[idx][0],
        self.data_table_ref.data[idx][1],
        preds,
    )

This method is called in on_epoch_end or an equivalent hook.

init_data_table

init_data_table(
    column_names: List[str]
) -> None

Initialize the W&B Tables for validation data.

Call this method in on_train_begin or an equivalent hook. Then add data to the table row-wise or column-wise.

Args
column_names (list) Column names for W&B Tables.

init_pred_table

init_pred_table(
    column_names: List[str]
) -> None

Initialize the W&B Tables for model evaluation.

Call this method in on_epoch_end or an equivalent hook. Then add data to the table row-wise or column-wise.

Args
column_names (list) Column names for W&B Tables.

log_data_table

log_data_table(
    name: str = "val",
    type: str = "dataset",
    table_name: str = "val_data"
) -> None

Log the data_table as a W&B artifact and call use_artifact on it.

This lets the evaluation table use the reference of already uploaded data (images, text, scalar, etc.) without re-uploading.

Args
name (str) A human-readable name for this artifact, which is how you can identify this artifact in the UI or reference it in use_artifact calls. (default is ‘val’)
type (str) The type of the artifact, which is used to organize and differentiate artifacts. (default is ‘dataset’)
table_name (str) The name of the table as will be displayed in the UI. (default is ‘val_data’).

log_pred_table

log_pred_table(
    type: str = "evaluation",
    table_name: str = "eval_data",
    aliases: Optional[List[str]] = None
) -> None

Log the W&B Tables for model evaluation.

The table will be logged multiple times, creating a new version each time. Use this to compare models at different intervals interactively.

Args
type (str) The type of the artifact, which is used to organize and differentiate artifacts. (default is ‘evaluation’)
table_name (str) The name of the table as will be displayed in the UI. (default is ‘eval_data’)
aliases (List[str]) List of aliases for the prediction table.

set_model

set_model(
    model
)

set_params

set_params(
    params
)

4 - WandbMetricsLogger

Logger that sends system metrics to W&B.

WandbMetricsLogger(
    log_freq: Union[LogStrategy, int] = "epoch",
    initial_global_step: int = 0,
    *args,
    **kwargs
) -> None

WandbMetricsLogger automatically logs to wandb the logs dictionary that Keras passes to callback methods.

This callback automatically logs the following to a W&B run page:

  • system (CPU/GPU/TPU) metrics,
  • train and validation metrics defined in model.compile,
  • learning rate (both for a fixed value or a learning rate scheduler)

Notes:

If you resume training by passing initial_epoch to model.fit and you are using a learning rate scheduler, make sure to pass initial_global_step to WandbMetricsLogger. The initial_global_step is step_size * initial_epoch, where step_size is the number of training steps per epoch. step_size can be calculated as the number of training samples divided by the batch size, rounded up unless the final partial batch is dropped. For an already batched tf.data.Dataset, it is the cardinality of the dataset.
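
As a worked example of this arithmetic, the sketch below computes initial_global_step for a resumed run. It assumes that step_size means the number of batches per epoch and that training resumes at initial_epoch; the helper name and the sample counts are hypothetical.

```python
import math


def compute_initial_global_step(num_train_samples, batch_size, initial_epoch,
                                drop_remainder=False):
    """Hypothetical helper: steps per epoch multiplied by the epochs already done."""
    if drop_remainder:
        # The final partial batch is discarded.
        steps_per_epoch = num_train_samples // batch_size
    else:
        # The final partial batch counts as one more step.
        steps_per_epoch = math.ceil(num_train_samples / batch_size)
    return steps_per_epoch * initial_epoch


# 50,000 samples with batch size 128 gives 391 steps per epoch.
step = compute_initial_global_step(50_000, 128, initial_epoch=5)
print(step)  # 1955
```

The resulting value would then be passed as WandbMetricsLogger(initial_global_step=step), alongside initial_epoch=5 in model.fit.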

Args
log_freq (epoch, batch, or an int) if epoch, logs metrics at the end of each epoch. If batch, logs metrics at the end of each batch. If an int, logs metrics at the end of that many batches. Defaults to epoch.
initial_global_step (int) Use this argument to correctly log the learning rate when you resume training from some initial_epoch, and a learning rate scheduler is used. This can be computed as step_size * initial_epoch, where step_size is the number of training steps per epoch. Defaults to 0.

Methods

set_model

set_model(
    model
)

set_params

set_params(
    params
)

5 - WandbModelCheckpoint

A checkpoint that periodically saves a Keras model or model weights.

WandbModelCheckpoint(
    filepath: StrPath,
    monitor: str = "val_loss",
    verbose: int = 0,
    save_best_only: bool = False,
    save_weights_only: bool = False,
    mode: Mode = "auto",
    save_freq: Union[SaveStrategy, int] = "epoch",
    initial_value_threshold: Optional[float] = None,
    **kwargs
) -> None

Saved weights are uploaded to W&B as a wandb.Artifact.

Since this callback is subclassed from tf.keras.callbacks.ModelCheckpoint, the checkpointing logic is taken care of by the parent callback. You can learn more here: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint

This callback is to be used in conjunction with training using model.fit() to save a model or weights (in a checkpoint file) at some interval. The model checkpoints will be logged as W&B Artifacts. You can learn more here: https://docs.wandb.ai/guides/core/artifacts

This callback provides the following features:

  • Save the model that has achieved best performance based on monitor.
  • Save the model at the end of every epoch regardless of the performance.
  • Save the model at the end of epoch or after a fixed number of training batches.
  • Save only model weights, or save the whole model.
  • Save the model either in SavedModel format or in .h5 format.

Args
filepath (Union[str, os.PathLike]) path to save the model file. filepath can contain named formatting options, which will be filled by the value of epoch and keys in logs (passed in on_epoch_end). For example: if filepath is model-{epoch:02d}-{val_loss:.2f}, then the model checkpoints will be saved with the epoch number and the validation loss in the filename.
monitor (str) The metric name to monitor. Default to val_loss.
verbose (int) Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1 displays messages when the callback takes an action.
save_best_only (bool) if save_best_only=True, the callback saves only when the model is considered the “best”, and the latest best model according to the monitored quantity will not be overwritten. If filepath doesn’t contain formatting options like {epoch}, then filepath will be overwritten locally by each new better model. The model logged as an artifact will still be associated with the correct monitor. Artifacts will be uploaded continuously and versioned separately as a new best model is found.
save_weights_only (bool) if True, then only the model’s weights will be saved.
mode (Mode) one of {‘auto’, ‘min’, ‘max’}. For val_acc, this should be max, for val_loss this should be min, etc.
save_freq (Union[SaveStrategy, int]) epoch or integer. When using 'epoch', the callback saves the model after each epoch. When using an integer, the callback saves the model at end of this many batches. Note that when monitoring validation metrics such as val_acc or val_loss, save_freq must be set to “epoch” as those metrics are only available at the end of an epoch.
initial_value_threshold (Optional[float]) Floating point initial “best” value of the metric to be monitored.
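
The named formatting options in filepath behave like Python's str.format placeholders. The sketch below shows how the example pattern from the filepath description expands; the epoch number and the values in logs are made up for illustration.

```python
# Pattern taken from the filepath description above.
filepath = "model-{epoch:02d}-{val_loss:.2f}"

# Illustrative values standing in for what on_epoch_end would receive.
epoch = 3
logs = {"loss": 0.3121, "val_loss": 0.4567}

# Unused keys in logs (here "loss") are simply ignored by str.format.
checkpoint_name = filepath.format(epoch=epoch, **logs)
print(checkpoint_name)  # model-03-0.46
```

Because each epoch produces a different name, such a pattern keeps earlier checkpoints on disk instead of overwriting one file.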

Methods

set_model

set_model(
    model
)

set_params

set_params(
    params
)