diff --git a/nanochat/common.py b/nanochat/common.py
index 8f36f94..9090097 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -162,11 +162,18 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
-    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
+    # Distributed setup: Distributed Data Parallel (DDP), optional (nccl on CUDA, gloo on CPU)
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp and device_type == "cuda":
-        device = torch.device("cuda", ddp_local_rank)
-        torch.cuda.set_device(device) # make "cuda" default to this device
-        dist.init_process_group(backend="nccl", device_id=device)
-        dist.barrier()
+    if ddp:
+        if device_type == "cuda":
+            device = torch.device("cuda", ddp_local_rank)
+            torch.cuda.set_device(device) # make "cuda" default to this device
+            dist.init_process_group(backend="nccl", device_id=device)
+            dist.barrier()
+        elif device_type == "cpu":
+            device = torch.device("cpu")
+            dist.init_process_group(backend="gloo")
+            dist.barrier()
+        else:
+            device = torch.device(device_type) # mps
     else:
         device = torch.device(device_type) # mps|cpu
@@ -177,7 +184,7 @@ def compute_cleanup():
     """Companion function to compute_init, to clean things up before script exit"""
-    if is_ddp():
+    if is_ddp() and dist.is_initialized():
         dist.destroy_process_group()

 class DummyWandb:
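
For reviewers who want to exercise the new CPU/gloo path without GPUs, here is a minimal smoke-test sketch. It assumes compute_init returns (ddp, ddp_rank, ddp_local_rank, ddp_world_size, device), matching the variables used inside the function, and that get_dist_info reads the RANK/LOCAL_RANK/WORLD_SIZE environment variables set by torchrun; the script name is hypothetical.

# smoke_ddp_cpu.py (hypothetical) -- exercise the gloo backend on CPU:
#   torchrun --nproc_per_node=2 smoke_ddp_cpu.py
import torch
import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup

# With device_type="cpu" under torchrun, compute_init should now call
# dist.init_process_group(backend="gloo") instead of falling through to
# the non-DDP branch. (The return signature is an assumption, see above.)
ddp, rank, local_rank, world_size, device = compute_init(device_type="cpu")

# All-reduce a tensor of ones: every rank should print world_size.
t = torch.ones(1, device=device)
if dist.is_initialized():
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
print(f"rank {rank}/{world_size}: all_reduce -> {t.item()}")

# Safe even when no process group was created, thanks to the added
# dist.is_initialized() check in compute_cleanup.
compute_cleanup()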