
Distributed Training with PyTorch

This guide demonstrates multi-node training with PyTorch. It assumes PyTorch 2.0 or newer is installed; to set up PyTorch, see the official installation instructions.

The code may also work with an older version of PyTorch, but verify that the PyTorch functions used below exist in your version.
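For example, a quick way to confirm which version you have installed (a minimal check, separate from the training script below):

import torch

print(torch.__version__)  # should report 2.0 or newer for this guide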

Create a my_example.py file and import the requisite packages:

import os 
import torch

# Setting these flags to True enables TF32 math and makes A100 training a lot faster:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# For model and ddp
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# For data
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torch.utils.data.distributed import DistributedSampler

# For DDP utils
from socket import gethostname

# For logging and command-line arguments
import logging
import argparse

Create a test model and some helper functions for loading data:

class CNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.fn = nn.Sequential(
            nn.Conv2d(3, 3, (7, 7), padding=3)
        )

    def forward(self, x):
        x = self.fn(x)
        return x


def get_transform(image_size):
    # Setup data:
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True
            ),
        ]
    )
    return transform

def get_cifar10(data_path):
    data = CIFAR10(data_path, train=True, transform=get_transform(32), download=True)
    return data

Create a logger (Optional)

def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="[\033[34m%(asctime)s\033[0m] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(f"{logging_dir}/log.txt"),
        ],
    )
    logger = logging.getLogger(__name__)
    return logger

Create the main function for the Python script. Pay attention to the comments!

def main(
    rank, # global rank of this process across all nodes
    local_rank, # GPU index of this process on its node
    args 
):
    assert args.global_batch_size % dist.get_world_size() == 0, "Global batch size must split evenly among ranks."

    # Set your random seed for experiment reproducibility.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)

    if rank == 0:
        experiment_dir = "./"
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
        logger.info(f"Batch size per rank: {args.global_batch_size // dist.get_world_size()}")

    # Grab the data and create a distributed data loader
    dataset = get_cifar10(args.data_path)

    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )

    loader = DataLoader(
        dataset,
        batch_size=int(args.global_batch_size // dist.get_world_size()),
        shuffle=False, # IMPORTANT: must be False because the DistributedSampler already shuffles
        sampler=sampler,
        num_workers=args.num_workers, # should match the number of CPUs allocated per task (--cpus-per-task)
        pin_memory=True,
        drop_last=True, # drop the last incomplete batch so every rank sees the same number of batches
    )

    # Create model
    model = CNN().to(local_rank)
    model = DDP(model, device_ids=[local_rank])
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(50):
        model.train()
        running_loss = log_steps = 0

        sampler.set_epoch(epoch)
        for x, _ in loader:
            opt.zero_grad()
            x = x.to(local_rank)
            pred = model(x)
            loss = torch.square(pred - x).mean() # toy reconstruction loss, just for demonstration
            loss.backward()
            opt.step()

            log_steps += 1
            running_loss += loss.item()

        avg_loss = torch.tensor(running_loss / log_steps, device=local_rank)            
        dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
        avg_loss = avg_loss.item() / dist.get_world_size()

        dist.barrier()
        if rank == 0:
            logger.info(f"(step={epoch}), Train Loss: {avg_loss:.5f}")

    dist.barrier() 
    if rank == 0:
        logger.info("Done!")
    dist.destroy_process_group()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--global-batch-size", type=int, default=128)
    parser.add_argument("--global-seed", type=int, default=3407)
    args = parser.parse_args()

    rank          = int(os.environ["SLURM_PROCID"])
    world_size    = int(os.environ["WORLD_SIZE"])
    gpus_per_node = int(os.environ["SLURM_GPUS_ON_NODE"])

    assert gpus_per_node == torch.cuda.device_count()
    print(f"Hello from rank {rank} of {world_size} on {gethostname()} where there are" \
          f" {gpus_per_node} allocated GPUs per node.", flush=True)

    local_rank = rank - gpus_per_node * (rank // gpus_per_node) # i.e. rank % gpus_per_node
    torch.cuda.set_device(local_rank) # set torch cuda to your specific GPU to avoid sharing among local ranks
    torch.cuda.empty_cache() # empty your device's cache just in case to prevent OOM

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True)
    main(rank, local_rank, args)

Building a Slurm script

Create a simple_job.sh script to submit as a job with Slurm.

MAKE SURE YOU CHANGE THE APPROPRIATE VARIABLES: the email address and the data path in this script.

#!/bin/bash
#SBATCH --job-name=simple-ddp-ex     # create a short name for your job
#SBATCH --partition=npl-2024     # appropriate partition; Slurm will pick one automatically if not specified
#SBATCH --nodes=2                # node count
#SBATCH --ntasks-per-node=2      # set this equal to the number of GPUs per node
#SBATCH --cpus-per-task=8        # cpu-cores per task (>1 if multi-threaded tasks)
#SBATCH --gres=gpu:2             # number of allocated gpus per node
#SBATCH --time=00:02:00          # total run time limit (HH:MM:SS)
#SBATCH --mail-type=begin        # send email when job begins
#SBATCH --mail-type=end          # send email when job ends
#SBATCH --mail-user=my_example@rpi.edu    # change this to your email!

# export your rank 0 information (its address and port)
export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_NTASKS_PER_NODE))
echo "WORLD_SIZE="$WORLD_SIZE
echo "MASTER_ADDR="$MASTER_ADDR

# local batch size is 128 per rank; 128 x (2 nodes x 2 GPUs) = 512 is the global batch size
srun python my_example.py --data-path ../data --global-batch-size 512 --num-workers 8 

Submit the job

sbatch simple_job.sh

You can check on your job with squeue and by inspecting the Slurm log file (slurm-<jobid>.out by default) created in the directory you submitted from.

Additional Tips and Information

Refer to the official Slurm documentation for additional information on Slurm commands and their usage.

When checkpointing the model during training, save it only from rank == 0 so that multiple processes do not write the same file. When reloading, however, every rank must load the checkpoint; a sketch of this pattern is shown below.
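A minimal sketch of that pattern, assuming it runs inside main() above and reuses its model, opt, rank, and local_rank variables; the checkpoint filename is only an example:

ckpt_path = "checkpoint.pt"

# Save from rank 0 only; model.module strips the DDP wrapper.
if rank == 0:
    torch.save({"model": model.module.state_dict(), "opt": opt.state_dict()}, ckpt_path)
dist.barrier() # wait until rank 0 has finished writing before anyone reads

# Reload on every rank, mapping tensors onto that rank's own GPU.
ckpt = torch.load(ckpt_path, map_location=f"cuda:{local_rank}")
model.module.load_state_dict(ckpt["model"])
opt.load_state_dict(ckpt["opt"])

The barrier keeps the other ranks from opening a file that rank 0 has not finished writing.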

Credit: Bao Pham, phamb@rpi.edu 3/3/24