Merge pull request #10 from LokiMetaSmith/fix-cpu-ddp-init

Fix process group initialization for CPU DDP and improve cleanup safety
Lawrence R Kincheloe III 2025-11-21 17:42:06 -06:00 committed by GitHub
commit b5fd54ac1c
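
Background on the bug: with the old `if ddp and device_type == "cuda":` guard, a multi-process CPU launch under torchrun fell into the non-DDP else branch, so no process group was ever created and any later collective (or DDP construction) on the rank would fail. A two-line sketch of the symptom, assuming a torchrun launch (the exact exception type and wording vary across PyTorch versions):

import torch.distributed as dist

# Launched via torchrun, but init_process_group() never ran: the first
# collective raises "Default process group has not been initialized".
dist.barrier()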

@@ -162,11 +162,18 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
     # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp and device_type == "cuda":
-        device = torch.device("cuda", ddp_local_rank)
-        torch.cuda.set_device(device) # make "cuda" default to this device
-        dist.init_process_group(backend="nccl", device_id=device)
-        dist.barrier()
+    if ddp:
+        if device_type == "cuda":
+            device = torch.device("cuda", ddp_local_rank)
+            torch.cuda.set_device(device) # make "cuda" default to this device
+            dist.init_process_group(backend="nccl", device_id=device)
+            dist.barrier()
+        elif device_type == "cpu":
+            device = torch.device("cpu")
+            dist.init_process_group(backend="gloo")
+            dist.barrier()
+        else:
+            device = torch.device(device_type) # mps
     else:
         device = torch.device(device_type) # mps|cpu
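
For readers who want to exercise the new gloo path in isolation, here is a minimal standalone sketch, not part of the commit. It assumes the standard torchrun env-var rendezvous (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), and the file name is hypothetical:

# cpu_ddp_smoke.py -- run with: torchrun --nproc_per_node=2 cpu_ddp_smoke.py
import torch
import torch.distributed as dist

def main():
    device = torch.device("cpu")
    # gloo is the CPU-capable backend; nccl requires CUDA devices
    dist.init_process_group(backend="gloo")  # reads env:// rendezvous vars
    dist.barrier()  # same sync point compute_init now reaches on CPU
    rank, world_size = dist.get_rank(), dist.get_world_size()
    # quick sanity check that collectives work on CPU tensors
    t = torch.ones(1, device=device) * rank
    dist.all_reduce(t)  # default op=SUM -> 0 + 1 + ... + (world_size - 1)
    print(f"rank {rank}/{world_size}: all_reduce sum = {t.item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()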
@@ -177,7 +184,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
 def compute_cleanup():
     """Companion function to compute_init, to clean things up before script exit"""
-    if is_ddp():
+    if is_ddp() and dist.is_initialized():
         dist.destroy_process_group()


 class DummyWandb:
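
The cleanup guard matters because destroy_process_group() raises when no default group exists, which is exactly the state a rank can be in if compute_init exited before init_process_group ran. A minimal sketch of the defensive pattern, assuming (as is common) that the repo's is_ddp() helper is keyed off the RANK env var:

import os
import torch.distributed as dist

def safe_cleanup():
    # A rank counts as "DDP-launched" if torchrun exported RANK, but the
    # process group may still never have been created (e.g. init failed,
    # or an unsupported device_type fell through). Guarding on
    # dist.is_initialized() makes cleanup a no-op in that case instead
    # of raising from destroy_process_group().
    if int(os.environ.get("RANK", -1)) != -1 and dist.is_initialized():
        dist.destroy_process_group()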