In distributed training, models are trained using multiple GPUs in parallel. To track distributed training using Weights & Biases, here are two patterns we support:
- **One Process**: Only call `wandb.log()` from a single process, e.g. the rank 0 process. This is the most common solution for logging with PyTorch DDP. In some cases, users funnel data over from other processes to the main logging process using a multiprocessing queue (or another communication primitive); a sketch of that pattern follows this list.
- **All Processes**: In every process, call `wandb.init()`. These are effectively separate experiments, so use the `group` parameter to set a shared experiment name and group the logged values together in the UI.
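The queue approach from the first pattern doesn't need any DDP machinery to illustrate. Below is a minimal, hypothetical sketch using Python's `multiprocessing` directly: the project name, the worker count, and the dummy loss values are placeholders, and in a real job the workers would be your training processes.

```python
import multiprocessing as mp

import wandb


def worker(rank, queue):
    # Stand-in for one training process: push metrics onto the shared
    # queue instead of calling wandb.log() directly.
    for step in range(100):
        loss = 1.0 / (step + 1)  # placeholder for a real loss value
        queue.put({"rank": rank, "loss": loss})
    queue.put(None)  # sentinel: this worker is done


if __name__ == "__main__":
    run = wandb.init(project="ddp-demo")  # only the parent process logs
    queue = mp.Queue()
    workers = [mp.Process(target=worker, args=(rank, queue)) for rank in range(2)]
    for p in workers:
        p.start()

    # Drain the queue in the logging process until every worker has finished.
    finished = 0
    while finished < len(workers):
        item = queue.get()
        if item is None:
            finished += 1
        else:
            run.log({f"loss/rank{item['rank']}": item["loss"]})

    for p in workers:
        p.join()
    run.finish()
```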
Below, you'll find a more thorough description of these two patterns, based on a code example from our repository of examples. Check out the "Common Issues" section at the bottom of this guide for some gotchas.
Sometimes a single GPU is insufficient for training large deep learning models on huge amounts of data, so we distribute our training runs across multiple GPUs. PyTorch DDP (`DistributedDataParallel`, from `torch.nn.parallel`) is a popular library for distributed training. In this walkthrough, we'll show how to track metrics with Weights & Biases using PyTorch DDP on two GPUs on a single machine. The basic principles apply to any distributed training setup, but the implementation details may differ.
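For context, the DDP side of such a script looks roughly like the sketch below. This is a minimal sketch, not W&B-specific, and it assumes the launcher (such as `torch.distributed.launch`, used later in this guide) has already set the usual environment variables (`RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`) for each process.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_ddp(model, local_rank):
    # Join the process group; the default "env://" init method reads
    # RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT from the environment.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    # Move the model to this process's GPU and wrap it in DDP so that
    # gradients are synchronized across processes during backward().
    model = model.cuda(local_rank)
    return DDP(model, device_ids=[local_rank])
```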
In multi-GPU training, the rank 0 process is the main process and coordinates the other processes. Often, it's useful to track just this single process as a W&B run, calling `wandb.init()` in just the rank 0 process and only calling `wandb.log()` there, not in any sub-processes.
This method is simple and robust, but it means that model metrics from other processes (e.g. loss values or inputs from their batches) are not logged. System metrics, like GPU utilization and memory, are still logged for all GPUs, since that information is available to all processes.
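In the training loop, that conditional logging might look like the sketch below. This is a minimal sketch: `loader` and `training_step` are hypothetical stand-ins for your dataloader and forward/backward pass, and `run` is the object returned by `wandb.init()` on rank 0 (and `None` in every other process, matching the example further down).

```python
def train(args, run=None):
    # `run` is the wandb run on the rank 0 process and None in every other process.
    for epoch in range(args.epochs):
        for batch in loader:                 # hypothetical dataloader
            loss = training_step(batch)      # hypothetical forward/backward step
            if run is not None:              # log only from the main process
                run.log({"batch_loss": loss.item()})
        if run is not None:
            run.log({"epoch": epoch})
```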
In our example of this method, we launch multiple processes with `torch.distributed.launch`. With this module, we can determine the rank of each process from the `--local_rank` argument. Now that we have the rank of the process, we can set up `wandb` logging conditionally in the `if __name__ == "__main__":` block:

```python
if __name__ == "__main__":
    # Get args
    args = parse_args()

    if args.local_rank == 0:  # only on the main process
        # Initialize wandb run
        run = wandb.init(
            entity=args.entity,
            project=args.project,
        )
        # Train model with DDP
        train(args, run)
    else:
        train(args)
```
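To run this example on two GPUs of a single machine, you would launch one process per GPU with something like `python -m torch.distributed.launch --nproc_per_node=2 train.py`, where `train.py` stands in for the actual script name; the launcher then passes the `--local_rank` argument to each copy of the script.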
If you want to see what the outputs look like for this method, check out an example dashboard here. There, you can see that system metrics, like temperature and utilization, were tracked for both GPUs.
The epoch-wise and batch-wise loss values, however, are only logged from a single GPU.
In this method, we track each process in the job, calling `wandb.log()` from each process separately. It's also useful to call `wandb.finish()` at the end of training to mark that the run has completed, so that all processes exit properly.
The benefit of this method is that more information is accessible for logging and that logging doesn't need to be made conditional on process rank in the code. However, it results in information from a single experiment being reported from multiple runs in the W&B UI.
To keep track of which runs correspond to which experiments, we use the grouping feature of Weights & Biases. It's as simple as setting the `group` parameter in `wandb.init()`. These results will be shown together on a group page in the W&B UI, so our experiments stay organized.
```python
if __name__ == "__main__":
    # Get args
    args = parse_args()

    # Initialize run
    run = wandb.init(
        entity=args.entity,
        project=args.project,
        group="DDP",  # all runs for the experiment in one group
    )
    # Train model with DDP
    train(args, run)
```
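Since every process creates its own run in this pattern, it's also worth adding a `run.finish()` (or `wandb.finish()`) call after `train(args, run)` returns, as noted above, so that each run is marked complete and every process exits cleanly.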
If you want to see what the outputs look like for this method, check out an example dashboard here. You'll see two runs grouped together in the sidebar. You can click on this group to get to the dedicated group page for the experiment, which displays metrics from each process separately.
If launching the `wandb` process hangs, it could be because the `wandb` multiprocessing is interfering with the multiprocessing from distributed training. Try setting the `WANDB_START_METHOD` environment variable to `"thread"` to use multithreading instead.
Is your process hanging at the end of training? The `wandb` process might not know it needs to exit, which will cause your job to hang. In this case, call `wandb.finish()` at the end of your script to mark the run as finished and cause `wandb` to exit.