Keras models
Use Weights & Biases for machine learning experiment tracking, dataset versioning, and project collaboration.
This Colab notebook introduces the WandbModelCheckpoint callback. Use this callback to log your model checkpoints to Weights & Biases Artifacts.
Setup and Installation
First, let us install the latest version of Weights & Biases. We will then authenticate this Colab instance to use W&B.
!pip install -qq -U wandb
import os
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
import tensorflow_datasets as tfds
# Weights and Biases related imports
import wandb
from wandb.integration.keras import WandbMetricsLogger
from wandb.integration.keras import WandbModelCheckpoint
If this is your first time using W&B or you are not logged in, the link that appears after running wandb.login() takes you to the sign-up/login page. Signing up for a free account takes only a few clicks.
wandb.login()
Hyperparameters
Using a proper config system is a recommended best practice for reproducible machine learning, and W&B can track the hyperparameters of every experiment. In this Colab, we use a simple Python dict as our config system.
configs = dict(
    num_classes = 10,
    shuffle_buffer = 1024,
    batch_size = 64,
    image_size = 28,
    image_channels = 1,
    earlystopping_patience = 3,
    learning_rate = 1e-3,
    epochs = 10
)
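If you want typed defaults and protection against accidental edits, one option (an assumption on our part, not something W&B requires) is a frozen dataclass that is converted back to a plain dict with `dataclasses.asdict`. The rest of the notebook keeps indexing `configs[...]`, and `wandb.init(config=configs)` still receives a dict. The class name `TrainConfig` is hypothetical.

```python
from dataclasses import asdict, dataclass

@dataclass(frozen=True)
class TrainConfig:
    # Hypothetical typed config: same fields and defaults as the configs dict above
    num_classes: int = 10
    shuffle_buffer: int = 1024
    batch_size: int = 64
    image_size: int = 28
    image_channels: int = 1
    earlystopping_patience: int = 3
    learning_rate: float = 1e-3
    epochs: int = 10

# frozen=True makes accidental mutation (e.g. cfg.epochs = 5) raise an error
default_cfg = TrainConfig()
# Override individual values for an experiment variant without touching the defaults
variant_cfg = TrainConfig(batch_size=128, learning_rate=3e-4)

# asdict() returns a plain dict identical to the hand-written configs above
configs = asdict(default_cfg)
print(configs["batch_size"], asdict(variant_cfg)["batch_size"])  # 64 128
```

Each variant becomes its own dict, so every W&B run records exactly the values it was trained with.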
Dataset
In this Colab, we use the Fashion-MNIST dataset from the TensorFlow Datasets catalog. We aim to build a simple image classification pipeline using TensorFlow/Keras.
train_ds, valid_ds = tfds.load('fashion_mnist', split=['train', 'test'])
AUTOTUNE = tf.data.AUTOTUNE
def parse_data(example):
    # Get image and cast it from uint8 to float32 (pixel values stay in [0, 255])
    image = tf.cast(example["image"], tf.float32)
    # Get label and one-hot encode it for categorical_crossentropy
    label = example["label"]
    label = tf.one_hot(label, depth=configs["num_classes"])
    return image, label
def get_dataloader(ds, configs, dataloader_type="train"):
    dataloader = ds.map(parse_data, num_parallel_calls=AUTOTUNE)
    if dataloader_type == "train":
        dataloader = dataloader.shuffle(configs["shuffle_buffer"])
    dataloader = (
        dataloader
        .batch(configs["batch_size"])
        .prefetch(AUTOTUNE)
    )
    return dataloader
trainloader = get_dataloader(train_ds, configs)
validloader = get_dataloader(valid_ds, configs, dataloader_type="valid")
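To make the tf.data calls above concrete, here is a pure-Python sketch of what `map` (one-hot encoding), `shuffle(buffer_size)` and `batch(batch_size)` do to a stream of examples. The helper names are hypothetical and the code only mimics the behaviour of tf.data; it is not its implementation. The key point is that `shuffle` keeps a fixed-size buffer and emits a random element from it, so a buffer of 1,024 against 60,000 training images gives only local shuffling. `batch` keeps the final partial batch, as tf.data does by default (`drop_remainder=False`), which is why the number of steps per epoch is a ceiling division.

```python
import math
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def one_hot(label: int, depth: int) -> List[float]:
    # Same encoding as tf.one_hot(label, depth) for an in-range integer label
    return [1.0 if i == label else 0.0 for i in range(depth)]

def shuffle_with_buffer(items: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    # Fill a buffer of buffer_size elements, then repeatedly emit a random
    # element from it and refill from the input stream
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Drain the remaining elements in random order once the input is exhausted
    while buffer:
        idx = rng.randrange(len(buffer))
        buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
        yield buffer.pop()

def batch(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    # Group consecutive elements; keep the final partial batch
    current: List[T] = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current

def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    # Number of batches in one pass when the last partial batch is kept
    return math.ceil(num_examples / batch_size)

# Toy stream of (example_id, label) pairs standing in for Fashion-MNIST records
examples = [(i, i % 10) for i in range(130)]
parsed = ((x, one_hot(y, depth=10)) for x, y in examples)
batches = list(batch(shuffle_with_buffer(parsed, buffer_size=16, rng=random.Random(0)), batch_size=64))
print([len(b) for b in batches])  # [64, 64, 2]
print(steps_per_epoch(60_000, 64), steps_per_epoch(10_000, 64))  # 938 157
```

With Fashion-MNIST's 60,000 training and 10,000 test images and `batch_size = 64`, one epoch is 938 training steps and 157 validation steps. The validation loader skips `shuffle`, since evaluation order does not affect the metrics.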
Model
def get_model(configs):
    # Frozen ImageNet-pretrained MobileNetV2 backbone without its classification head
    backbone = tf.keras.applications.mobilenet_v2.MobileNetV2(weights='imagenet', include_top=False)
    backbone.trainable = False
    inputs = layers.Input(shape=(configs["image_size"], configs["image_size"], configs["image_channels"]))
    # Upsample 28x28 images to 32x32, the smallest input size MobileNetV2 accepts
    resize = layers.Resizing(32, 32)(inputs)
    # Learnable 3x3 conv maps the single grayscale channel to the 3 channels the backbone expects
    neck = layers.Conv2D(3, (3, 3), padding="same")(resize)
    preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input(neck)
    x = backbone(preprocess_input)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(configs["num_classes"], activation="softmax")(x)
    return models.Model(inputs=inputs, outputs=outputs)
tf.keras.backend.clear_session()
model = get_model(configs)
model.summary()
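Because the backbone is frozen, the trainable parameters reported by `model.summary()` should come only from the Conv2D neck and the Dense head. The hand check below assumes MobileNetV2 with its default width (`alpha=1.0`) downsamples its input by 32 overall and ends in 1,280 channels. Under those assumptions a 32x32 input yields a 1x1x1280 feature map, and global average pooling hands a 1,280-dimensional vector to the classifier. The helper names are hypothetical.

```python
def conv2d_params(kernel_size: int, in_channels: int, out_channels: int) -> int:
    # A k x k Conv2D with bias: k*k*in*out weights plus one bias per filter
    return kernel_size * kernel_size * in_channels * out_channels + out_channels

def dense_params(in_features: int, out_features: int) -> int:
    # A Dense layer with bias: weight matrix plus one bias per unit
    return in_features * out_features + out_features

IMAGE_CHANNELS = 1         # grayscale Fashion-MNIST input
RESIZED = 32               # output size of layers.Resizing(32, 32)
NUM_CLASSES = 10
BACKBONE_STRIDE = 32       # assumption: total downsampling factor of MobileNetV2
BACKBONE_CHANNELS = 1280   # assumption: channels of MobileNetV2's final feature map (alpha=1.0)

feature_map_size = RESIZED // BACKBONE_STRIDE
neck_params = conv2d_params(3, IMAGE_CHANNELS, 3)
head_params = dense_params(BACKBONE_CHANNELS, NUM_CLASSES)
# Resizing and GlobalAveragePooling2D have no weights, and the backbone is frozen
trainable_params = neck_params + head_params

print(feature_map_size, neck_params, head_params, trainable_params)  # 1 30 12810 12840
```

If the summary disagrees, a likely cause is that the backbone was accidentally left trainable.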
Compile Model
model.compile(
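    # Use the learning rate from configs so the value logged to W&B is the one actually used
    optimizer = tf.keras.optimizers.Adam(learning_rate=configs["learning_rate"]),
    loss = "categorical_crossentropy",
    metrics = ["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top@5_accuracy')]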
)
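The `top@5_accuracy` metric counts a prediction as correct when the true class is among the five highest-scoring classes. Here is a simplified pure-Python version for intuition. It is an illustration, not Keras's implementation, and it ignores Keras's handling of tied scores. With `k=1` it reduces to ordinary accuracy.

```python
from typing import Sequence

def top_k_categorical_accuracy(
    y_true: Sequence[Sequence[float]],
    y_pred: Sequence[Sequence[float]],
    k: int = 5,
) -> float:
    # Fraction of examples whose true class (argmax of the one-hot label)
    # is among the k highest predicted scores
    hits = 0
    for true_row, pred_row in zip(y_true, y_pred):
        true_class = max(range(len(true_row)), key=lambda i: true_row[i])
        top_k = sorted(range(len(pred_row)), key=lambda i: pred_row[i], reverse=True)[:k]
        hits += true_class in top_k
    return hits / len(y_true)

y_true = [[0, 0, 1], [1, 0, 0]]
y_pred = [[0.1, 0.3, 0.6], [0.2, 0.5, 0.3]]
print(top_k_categorical_accuracy(y_true, y_pred, k=1))  # 0.5
print(top_k_categorical_accuracy(y_true, y_pred, k=3))  # 1.0
```

With only 10 Fashion-MNIST classes, top-5 accuracy is a lenient metric, so it tends to sit well above plain accuracy.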
Train
# Initialize a W&B run
run = wandb.init(
    project = "intro-keras",
    config = configs
)
# Train your model
model.fit(
    trainloader,
    epochs = configs["epochs"],
    validation_data = validloader,
    callbacks = [
        WandbMetricsLogger(log_freq=10),
        WandbModelCheckpoint(filepath="models/")  # Notice the use of WandbModelCheckpoint here
    ]
)
# Close the W&B run
run.finish()
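The config defines `earlystopping_patience = 3`, but the training call above does not use it. In Keras it would typically be wired up by appending `tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=configs["earlystopping_patience"])` to the `callbacks` list. The sketch below shows the patience rule in plain Python under simplifying assumptions: strict improvement only, with no `min_delta`, `baseline`, or `restore_best_weights`. The function name is hypothetical.

```python
from typing import Optional, Sequence

def early_stopping_epoch(val_losses: Sequence[float], patience: int) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None if it runs to completion."""
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # Strict improvement: remember the new best and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Validation loss stops improving after epoch 1; with patience 3, training stops at epoch 4
print(early_stopping_epoch([1.0, 0.8, 0.9, 0.85, 0.82, 0.7], patience=3))  # 4
print(early_stopping_epoch([1.0, 0.9, 0.8], patience=3))  # None
```

Stopping early also means fewer checkpoints are saved, since the checkpoint callback only runs for the epochs that actually execute.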