Fix process group initialization for CPU DDP and improve cleanup safety

google-labs-jules[bot] 2025-11-21 23:41:34 +00:00
parent 104308cf78
commit 9235fe4000


@@ -162,11 +162,18 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
     # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp and device_type == "cuda":
-        device = torch.device("cuda", ddp_local_rank)
-        torch.cuda.set_device(device) # make "cuda" default to this device
-        dist.init_process_group(backend="nccl", device_id=device)
-        dist.barrier()
+    if ddp:
+        if device_type == "cuda":
+            device = torch.device("cuda", ddp_local_rank)
+            torch.cuda.set_device(device) # make "cuda" default to this device
+            dist.init_process_group(backend="nccl", device_id=device)
+            dist.barrier()
+        elif device_type == "cpu":
+            device = torch.device("cpu")
+            dist.init_process_group(backend="gloo")
+            dist.barrier()
+        else:
+            device = torch.device(device_type) # mps
     else:
         device = torch.device(device_type) # mps|cpu
@@ -177,7 +184,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
 def compute_cleanup():
     """Companion function to compute_init, to clean things up before script exit"""
-    if is_ddp():
+    if is_ddp() and dist.is_initialized():
         dist.destroy_process_group()

 class DummyWandb:
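
For context, a minimal standalone sketch (not part of the commit) of the pattern the new branch enables: CPU-only DDP initialized with the gloo backend, plus a cleanup guarded by dist.is_initialized() as in the updated compute_cleanup. The script name and the RANK-based DDP detection are assumptions for illustration, not taken from the repo; launch with, e.g., torchrun --nproc_per_node=2 ddp_cpu_sketch.py.

    # ddp_cpu_sketch.py -- hypothetical example, not from the commit
    import os
    import torch
    import torch.distributed as dist

    def main():
        # torchrun sets RANK/WORLD_SIZE; absence of RANK means single-process mode
        ddp = int(os.environ.get("RANK", -1)) != -1
        if ddp:
            # gloo is the CPU backend; unlike nccl, no device_id is passed
            dist.init_process_group(backend="gloo")
            dist.barrier()
        device = torch.device("cpu")

        # placeholder work on `device`; all_reduce sanity-checks the process group
        x = torch.ones(1, device=device)
        if ddp:
            dist.all_reduce(x)  # defaults to SUM across ranks
        print(f"rank {os.environ.get('RANK', 0)}: {x.item()}")

        # mirror the guarded destroy from compute_cleanup()
        if ddp and dist.is_initialized():
            dist.destroy_process_group()

    if __name__ == "__main__":
        main()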