Documentation
中文 (简体)
搜索文档…
⌃K

JAX Example

这是一个训练MLP并保存到W&B的完整JAX代码示例。
你可以在GitHub 上找到这个示例,并在W&B上查看结果。
import time
import itertools
import numpy.random as npr
import wandb
import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
import datasets
def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
return -np.mean(preds * targets)
def accuracy(params, batch):
inputs, targets = batch
target_class = np.argmax(targets, axis=1)
predicted_class = np.argmax(predict(params, inputs), axis=1)
return np.mean(predicted_class == target_class)
init_random_params, predict = stax.serial(
Dense(1024), Relu,
Dense(1024), Relu,
Dense(10), LogSoftmax)
if __name__ == "__main__":
wandb.init()
rng = random.PRNGKey(0)
wandb.config.step_size = 0.001
wandb.config.num_epochs = 10
wandb.config.batch_size = 128
wandb.config.momentum_mass = 0.9
train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, wandb.config.batch_size)
num_batches = num_complete_batches + bool(leftover)
def data_stream():
rng = npr.RandomState(0)
while True:
perm = rng.permutation(num_train)
for i in range(num_batches):
batch_idx = perm[i * wandb.config.batch_size:(i + 1) * wandb.config.batch_size]
yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()
opt_init, opt_update, get_params = optimizers.momentum(wandb.config.step_size, mass=wandb.config.momentum_mass)
@jit
def update(i, opt_state, batch):
params = get_params(opt_state)
return opt_update(i, grad(loss)(params, batch), opt_state)
_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()
print("\nStarting training...")
for epoch in range(wandb.config.num_epochs):
start_time = time.time()
for _ in range(num_batches):
opt_state = update(next(itercount), opt_state, next(batches))
epoch_time = time.time() - start_time
params = get_params(opt_state)
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
wandb.log({"Train Accuracy": float(train_acc), "Test Accuracy": float(test_acc)})