This is a complete example of JAX code that trains an MLP on MNIST and logs metrics to W&B.
You can find this example on GitHub and see the results on W&B.
```python
import time
import itertools

import numpy.random as npr

import wandb
import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
import datasets


def loss(params, batch):
    # Negative log-likelihood: predict returns log-probabilities
    # and targets are one-hot, so this picks out the target log-prob.
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(preds * targets)


def accuracy(params, batch):
    # Fraction of examples where the argmax prediction matches the label.
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


# Two hidden layers of 1024 ReLU units, then a 10-way log-softmax output.
init_random_params, predict = stax.serial(
    Dense(1024), Relu,
    Dense(1024), Relu,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
    wandb.init()
    rng = random.PRNGKey(0)

    # Store hyperparameters on wandb.config so they are recorded with the run.
    wandb.config.step_size = 0.001
    wandb.config.num_epochs = 10
    wandb.config.batch_size = 128
    wandb.config.momentum_mass = 0.9

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, wandb.config.batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        # Infinite generator of shuffled minibatches.
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * wandb.config.batch_size:(i + 1) * wandb.config.batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.momentum(
        wandb.config.step_size, mass=wandb.config.momentum_mass)

    @jit
    def update(i, opt_state, batch):
        # One JIT-compiled optimizer step.
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    print("\nStarting training...")
    for epoch in range(wandb.config.num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.time() - start_time

        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        # Log metrics to W&B once per epoch.
        wandb.log({"Train Accuracy": float(train_acc),
                   "Test Accuracy": float(test_acc)})
```
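The script imports a local `datasets` module, which here refers to the MNIST helper shipped alongside the JAX examples. If you don't have that file, a minimal stand-in might look like the sketch below. It assumes TensorFlow is installed purely for its MNIST loader; the module name `datasets`, the `_one_hot` helper, and the choice of loader are all illustrative, and any loader that returns flattened, normalized images with one-hot labels would work equally well.

```python
# Hypothetical minimal stand-in for the `datasets` module used above.
import numpy as onp
from tensorflow.keras.datasets import mnist as keras_mnist


def _one_hot(labels, num_classes=10):
    # Convert integer labels to one-hot rows, e.g. 3 -> [0, 0, 0, 1, 0, ...].
    return onp.eye(num_classes)[labels].astype(onp.float32)


def mnist():
    (train_x, train_y), (test_x, test_y) = keras_mnist.load_data()
    # Flatten 28x28 images to 784-vectors and scale pixels to [0, 1],
    # matching the (-1, 28 * 28) input shape used by init_random_params.
    train_images = train_x.reshape(-1, 28 * 28).astype(onp.float32) / 255.0
    test_images = test_x.reshape(-1, 28 * 28).astype(onp.float32) / 255.0
    return train_images, _one_hot(train_y), test_images, _one_hot(test_y)
```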
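Depending on your JAX version, the `jax.experimental.optimizers` and `jax.experimental.stax` imports above may fail, since those modules were later moved out of `jax.experimental`. Assuming a JAX release that ships `jax.example_libraries`, the equivalent imports are:

```python
# On newer JAX releases, stax and optimizers live under jax.example_libraries:
from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax
```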