from composer import Callback, State, Logger
class LogPredictions(Callback):
    """Log a table of sample predictions to Weights & Biases.

    While evaluating on the last epoch, up to ``num_samples`` rows of
    ``[image, ground truth, prediction]`` are accumulated in ``self.data``.
    When evaluation ends, those rows are logged as a ``wandb.Table`` under the
    key ``'sample_table'``.

    NOTE(review): the methods use the module-level name ``wandb``; no import
    of it is visible next to this class -- make sure ``import wandb`` is
    present in the file.
    """

    def __init__(self, num_samples=100, seed=1234):
        """Set up an empty row buffer.

        Args:
            num_samples: Maximum number of rows collected and logged.
            seed: Stored on the instance for interface compatibility; nothing
                in this class consumes it.
        """
        super().__init__()
        self.num_samples = num_samples
        self.seed = seed
        # Rows of [wandb.Image, target, prediction] gathered during eval.
        self.data = []

    def eval_batch_end(self, state: State, logger: Logger):
        """Compute predictions for the batch and append them to ``self.data``.

        Rows are only collected when ``state.timer.epoch`` equals
        ``state.max_duration`` and fewer than ``num_samples`` rows are stored.
        """
        on_last_epoch = state.timer.epoch == state.max_duration
        if not on_last_epoch:
            return

        remaining = self.num_samples - len(self.data)
        if remaining <= 0:
            return

        # NOTE(review): assumes the batch is an (inputs, targets) pair exposed
        # as ``state.batch_pair`` by the Composer version that also provides
        # ``state.timer`` -- confirm against the installed Composer release.
        x, y = state.batch_pair
        # Index of the highest score along the last axis of the model outputs.
        outputs = state.outputs.argmax(-1)
        self.data.extend(
            [wandb.Image(x_i), y_i, y_pred]
            for x_i, y_i, y_pred in zip(
                x[:remaining], y[:remaining], outputs[:remaining]
            )
        )

    def eval_end(self, state: State, logger: Logger):
        """Create a ``wandb.Table`` from the collected rows and log it."""
        columns = ['image', 'ground truth', 'prediction']
        table = wandb.Table(columns=columns, data=self.data[:self.num_samples])
        wandb.log({'sample_table': table}, step=int(state.timer.batch))
# Callback list holding one LogPredictions instance built with its defaults.
# NOTE(review): presumably passed as the ``callbacks`` argument of a Composer
# Trainer -- confirm at the call site.
callbacks = [
    LogPredictions(),
]