Documentation
搜索文档…
JAX Example
这是一个训练 MLP 并将结果保存到 W&B 的完整 JAX 代码示例。
你可以在 GitHub 上找到这个示例,并在 W&B 上查看结果。
1
import time
2
import itertools
3
4
import numpy.random as npr
5
import wandb
6
7
import jax.numpy as np
8
from jax.config import config
9
from jax import jit, grad, random
10
from jax.experimental import optimizers
11
from jax.experimental import stax
12
from jax.experimental.stax import Dense, Relu, LogSoftmax
13
import datasets
14
15
16
def loss(params, batch):
    """Mean cross-entropy of the model on a batch.

    `batch` is an (inputs, one-hot targets) pair; `predict` returns
    log-probabilities (LogSoftmax head), so multiplying by the one-hot
    targets keeps exactly each example's target log-probability.
    """
    inputs, targets = batch
    return -np.mean(predict(params, inputs) * targets)
def accuracy(params, batch):
    """Fraction of examples whose argmax prediction matches the target.

    `batch` is an (inputs, one-hot targets) pair.
    """
    inputs, targets = batch
    true_classes = np.argmax(targets, axis=1)
    guessed_classes = np.argmax(predict(params, inputs), axis=1)
    return np.mean(guessed_classes == true_classes)
# Three-layer MLP: two 1024-unit ReLU hidden layers and a 10-way
# log-softmax output head. `init_random_params(rng, input_shape)` builds
# the parameter pytree; `predict(params, inputs)` returns log-probabilities.
init_random_params, predict = stax.serial(
    Dense(1024),
    Relu,
    Dense(1024),
    Relu,
    Dense(10),
    LogSoftmax,
)
if __name__ == "__main__":
    wandb.init()
    rng = random.PRNGKey(0)

    # Hyperparameters, recorded in the run's W&B config.
    wandb.config.step_size = 0.001
    wandb.config.num_epochs = 10
    wandb.config.batch_size = 128
    wandb.config.momentum_mass = 0.9

    # Hoist config reads into locals: avoids repeated `wandb.config`
    # attribute lookups inside the per-batch loops below.
    batch_size = wandb.config.batch_size
    num_epochs = wandb.config.num_epochs

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    # A final partial batch, if any, counts as one more batch.
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        """Yield (images, labels) minibatches forever, reshuffling each epoch."""
        stream_rng = npr.RandomState(0)  # fixed seed: reproducible batch order
        while True:
            perm = stream_rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.momentum(
        wandb.config.step_size, mass=wandb.config.momentum_mass)

    @jit
    def update(i, opt_state, batch):
        """One SGD-with-momentum step on `batch`; `i` is the global step count."""
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    # (-1, 28 * 28): batch dimension left free, flattened MNIST pixels.
    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.time() - start_time

        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        # Fix: the original computed `epoch_time` but never reported it, and
        # logged accuracies without the epoch index. Report both so each
        # epoch's duration is visible and W&B charts can be keyed by epoch.
        print("Epoch {} in {:0.2f} sec: train {:.4f}, test {:.4f}".format(
            epoch, epoch_time, float(train_acc), float(test_acc)))
        wandb.log({"Train Accuracy": float(train_acc),
                   "Test Accuracy": float(test_acc),
                   "Epoch": epoch,
                   "Epoch Time (s)": epoch_time})
Copied!
最近更新于 9 个月前
复制链接