WandbEvalCallback
Abstract base class to build Keras callbacks for model prediction visualization.
WandbEvalCallback(
data_table_columns: List[str],
pred_table_columns: List[str],
*args,
**kwargs
) -> None
You can build callbacks for visualizing model predictions on_epoch_end
that can be passed to model.fit()
for classification, object detection,
segmentation, etc. tasks.
To use this, inherit from this base callback class and implement the
add_ground_truth
and add_model_prediction
methods.
The base class will take care of the following:
- Initialize
data_table
for logging the ground truth andpred_table
for predictions. - The data uploaded to
data_table
is used as a reference for thepred_table
. This is to reduce the memory footprint. Thedata_table_ref
is a list that can be used to access the referenced data. Check out the example below to see how it's done. - Log the tables to W&B as W&B Artifacts.
- Each new
pred_table
is logged as a new version with aliases.
Example:
class WandbClfEvalCallback(WandbEvalCallback):
def __init__(self, validation_data, data_table_columns, pred_table_columns):
super().__init__(data_table_columns, pred_table_columns)
self.x = validation_data[0]
self.y = validation_data[1]
def add_ground_truth(self):
for idx, (image, label) in enumerate(zip(self.x, self.y)):
self.data_table.add_data(idx, wandb.Image(image), label)
def add_model_predictions(self, epoch):
preds = self.model.predict(self.x, verbose=0)
preds = tf.argmax(preds, axis=-1)
data_table_ref = self.data_table_ref
table_idxs = data_table_ref.get_index()
for idx in table_idxs:
pred = preds[idx]
self.pred_table.add_data(
epoch,
data_table_ref.data[idx][0],
data_table_ref.data[idx][1],
data_table_ref.data[idx][2],
pred,
)
model.fit(
x,
y,
epochs=2,
validation_data=(x, y),
callbacks=[
WandbClfEvalCallback(
validation_data=(x, y),
data_table_columns=["idx", "image", "label"],
pred_table_columns=["epoch", "idx", "image", "label", "pred"],
)
],
)
To have more fine-grained control, you can override the on_train_begin
and
on_epoch_end
methods. If you want to log the samples after N batched, you
can implement on_train_batch_end
method.
Methods
add_ground_truth
@abc.abstractmethod
add_ground_truth(
logs: Optional[Dict[str, float]] = None
) -> None
Add ground truth data to data_table
.
Use this method to write the logic for adding validation/training data to
data_table
initialized using init_data_table
method.
Example:
for idx, data in enumerate(dataloader):
self.data_table.add_data(idx, data)
This method is called once on_train_begin
or equivalent hook.
add_model_predictions
@abc.abstractmethod
add_model_predictions(
epoch: int,
logs: Optional[Dict[str, float]] = None
) -> None
Add a prediction from a model to pred_table
.
Use this method to write the logic for adding model prediction for validation/
training data to pred_table
initialized using init_pred_table
method.
Example:
# Assuming the dataloader is not shuffling the samples.
for idx, data in enumerate(dataloader):
preds = model.predict(data)
self.pred_table.add_data(
self.data_table_ref.data[idx][0],
self.data_table_ref.data[idx][1],
preds,
)
This method is called on_epoch_end
or equivalent hook.
init_data_table
init_data_table(
column_names: List[str]
) -> None
Initialize the W&B Tables for validation data.
Call this method on_train_begin
or equivalent hook. This is followed by adding
data to the table row or column wise.
Args | |
---|---|
column_names | (list) Column names for W&B Tables. |
init_pred_table
init_pred_table(
column_names: List[str]
) -> None
Initialize the W&B Tables for model evaluation.
Call this method on_epoch_end
or equivalent hook. This is followed by adding
data to the table row or column wise.
Args | |
---|---|
column_names | (list) Column names for W&B Tables. |
log_data_table
log_data_table(
name: str = "val",
type: str = "dataset",
table_name: str = "val_data"
) -> None
Log the data_table
as W&B artifact and call use_artifact
on it.
This lets the evaluation table use the reference of already uploaded data (images, text, scalar, etc.) without re-uploading.
Args | |
---|---|
name | (str) A human-readable name for this artifact, which is how you can identify this artifact in the UI or reference it in use_artifact calls. (default is 'val') |
type | (str) The type of the artifact, which is used to organize and differentiate artifacts. (default is 'dataset') |
table_name | (str) The name of the table as will be displayed in the UI. (default is 'val_data'). |
log_pred_table
log_pred_table(
type: str = "evaluation",
table_name: str = "eval_data",
aliases: Optional[List[str]] = None
) -> None
Log the W&B Tables for model evaluation.
The table will be logged multiple times creating new version. Use this to compare models at different intervals interactively.
Args | |
---|---|
type | (str) The type of the artifact, which is used to organize and differentiate artifacts. (default is 'evaluation') |
table_name | (str) The name of the table as will be displayed in the UI. (default is 'eval_data') |
aliases | (List[str]) List of aliases for the prediction table. |
set_model
set_model(
model
)
set_params
set_params(
params
)