Stable Baselines 3
Stable Baselines 3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. W&B's SB3 integration will:
  • Record metrics such as losses and episodic returns
  • Upload videos of agents playing the games
  • Save the trained model
  • Log the model's hyperparameters
  • Log model gradient histograms
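The episodic returns mentioned above come from SB3's `Monitor` wrapper, which the basic example below wraps around each environment. The following is a rough, stdlib-only sketch of that bookkeeping, not SB3's code; the function `episode_stats` is hypothetical. It treats an episodic return as the sum of rewards from the start of an episode until the environment reports that the episode is done:

```python
from typing import List, Tuple


def episode_stats(rewards: List[float], dones: List[bool]) -> List[Tuple[float, int]]:
    """Return (episodic return, episode length) for every finished episode.

    Hypothetical helper illustrating the idea behind SB3's Monitor wrapper:
    rewards are summed until the environment signals that the episode is done.
    An episode that has not finished yet is not reported.
    """
    stats = []
    total, length = 0.0, 0
    for reward, done in zip(rewards, dones):
        total += reward
        length += 1
        if done:
            stats.append((total, length))
            total, length = 0.0, 0  # start counting the next episode
    return stats


# CartPole gives a reward of +1 per step, so each return equals the episode length.
print(episode_stats([1.0] * 5, [False, False, True, False, True]))
# → [(3.0, 3), (2.0, 2)]
```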
Here is an example of an SB3 training run with W&B.

Log your SB3 experiments in two lines of code

from wandb.integration.sb3 import WandbCallback

model.learn(..., callback=WandbCallback())
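`WandbCallback` plugs into SB3's callback mechanism: `model.learn` calls the callback as training progresses, and the callback decides what to send to W&B. The toy loop below is only a sketch of that hook pattern under simplifying assumptions. `ToyCallback`, `StopAfter`, and `toy_learn` are hypothetical names, not SB3 or W&B classes, and a real SB3 loop also collects rollouts and updates the policy. As in SB3, a callback step that returns `False` stops training.

```python
from typing import List


class ToyCallback:
    """Hypothetical stand-in for an SB3-style callback (not the real BaseCallback)."""

    def __init__(self, log_every: int) -> None:
        self.log_every = log_every
        self.n_calls = 0
        self.logged_steps: List[int] = []

    def on_step(self) -> bool:
        self.n_calls += 1
        if self.n_calls % self.log_every == 0:
            # Stand-in for "send the current metrics to W&B".
            self.logged_steps.append(self.n_calls)
        return True  # returning False asks the training loop to stop


class StopAfter(ToyCallback):
    """Hypothetical callback that asks the loop to stop after max_calls calls."""

    def __init__(self, max_calls: int) -> None:
        super().__init__(log_every=1)
        self.max_calls = max_calls

    def on_step(self) -> bool:
        super().on_step()
        return self.n_calls < self.max_calls


def toy_learn(total_timesteps: int, callback: ToyCallback) -> int:
    """Hypothetical training loop: one environment step, then one callback call."""
    steps = 0
    while steps < total_timesteps:
        steps += 1  # env.step(...) and policy updates would happen here
        if not callback.on_step():
            break
    return steps


cb = ToyCallback(log_every=10)
print(toy_learn(35, cb), cb.logged_steps)  # → 35 [10, 20, 30]
print(toy_learn(100, StopAfter(5)))  # → 5
```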

WandbCallback Arguments

  • verbose: The verbosity of SB3 output
  • model_save_path: Path to the folder where the model will be saved. The default value is `None`, so the model is not logged.
  • model_save_freq: Frequency to save the model
  • gradient_save_freq: Frequency to log gradients. The default value is 0, so gradients are not logged.
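`gradient_save_freq` follows a common "every N steps, 0 means off" convention. The following is a minimal sketch of that convention, assuming a simple modulo check. It is not W&B's implementation, the helpers `should_fire` and `steps_that_fire` are hypothetical, and what counts as one step differs between arguments and is not specified here.

```python
from typing import List


def should_fire(step: int, freq: int) -> bool:
    """Hypothetical helper: act every `freq` steps; a frequency of 0 disables it."""
    return freq > 0 and step % freq == 0


def steps_that_fire(total_steps: int, freq: int) -> List[int]:
    """List the 1-based steps at which the action would run."""
    return [s for s in range(1, total_steps + 1) if should_fire(s, freq)]


print(steps_that_fire(500, 100))  # → [100, 200, 300, 400, 500]
print(steps_that_fire(500, 0))  # → [] (frequency 0: never fires)
```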

Basic Example

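The W&B SB3 integration reads the metrics that SB3 writes to TensorBoard and logs them to W&B. In the example below, this is enabled by passing `sync_tensorboard=True` to `wandb.init` and a `tensorboard_log` directory to the model.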
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb
from wandb.integration.sb3 import WandbCallback


config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 25000,
    "env_name": "CartPole-v1",
}
run = wandb.init(
    project="sb3",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=True,  # optional
)


def make_env():
    env = gym.make(config["env_name"])
    env = Monitor(env)  # record stats such as returns
    return env


env = DummyVecEnv([make_env])
env = VecVideoRecorder(
    env,
    f"videos/{run.id}",
    record_video_trigger=lambda x: x % 2000 == 0,
    video_length=200,
)
model = PPO(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        gradient_save_freq=100,
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)
run.finish()
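In the basic example, `VecVideoRecorder` starts a recording whenever `record_video_trigger` returns `True` for the current step and records `video_length` steps. The sketch below estimates which step ranges would end up in videos. It assumes the trigger is checked once per vectorized environment step starting at step 0 and that a new recording cannot start while one is in progress. `recording_windows` is a hypothetical helper, not part of SB3.

```python
from typing import List, Tuple


def recording_windows(total_steps: int, trigger_every: int, video_length: int) -> List[Tuple[int, int]]:
    """Return (start, end) step ranges that would be recorded, end exclusive.

    Hypothetical helper; the trigger condition has the same shape as the
    example's `lambda x: x % 2000 == 0`.
    """
    windows = []
    step = 0
    while step < total_steps:
        if step % trigger_every == 0:
            # Record video_length steps, or fewer if training ends first.
            end = min(step + video_length, total_steps)
            windows.append((step, end))
            step = end
        else:
            step += 1
    return windows


print(recording_windows(6000, 2000, 200))
# → [(0, 200), (2000, 2200), (4000, 4200)]
```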