Commit e3393045a4 by Dipesh Babu, 2025-12-05 17:38:03 -06:00 (committed by GitHub)


@@ -113,12 +113,24 @@ def print_banner():
     """
     print0(banner)
 
-def is_ddp():
-    # TODO is there a proper way
-    return int(os.environ.get('RANK', -1)) != -1
+def is_ddp_requested() -> bool:
+    """
+    True if launched by torchrun (env present), even before init.
+    Used to decide whether we *should* initialize a PG.
+    """
+    return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
+
+def is_ddp_initialized() -> bool:
+    """
+    True if torch.distributed is available and the process group is initialized.
+    Used at cleanup to avoid destroying a non-existent PG.
+    """
+    return dist.is_available() and dist.is_initialized()
 
 def get_dist_info():
-    if is_ddp():
+    if is_ddp_requested():
+        # We rely on torchrun's env to decide if we SHOULD init.
+        # (Initialization itself happens in compute init.)
         assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
         ddp_rank = int(os.environ['RANK'])
         ddp_local_rank = int(os.environ['LOCAL_RANK'])
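
The split between the two helpers is about lifecycle: torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE before any user code runs, while the process group only exists after init_process_group() has been called. A minimal standalone sketch of that lifecycle, not part of the commit; it assumes a `torchrun --nproc_per_node=1 sketch.py` launch and uses the gloo backend so it also runs without CUDA:

import os
import torch.distributed as dist

# Before init: torchrun has already set the env vars, but no process group exists yet.
requested = all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
initialized = dist.is_available() and dist.is_initialized()
print(f"requested={requested} initialized={initialized}")  # True / False under torchrun

if requested:
    dist.init_process_group(backend="gloo")        # env:// rendezvous from torchrun's variables
    print(f"initialized={dist.is_initialized()}")  # now True
    dist.destroy_process_group()

Run with plain `python sketch.py` and both checks are False, which is exactly the case the new names are meant to keep apart.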
@@ -161,8 +173,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
     torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
 
     # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
-    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp and device_type == "cuda":
+    is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    if is_ddp_requested and device_type == "cuda":
         device = torch.device("cuda", ddp_local_rank)
         torch.cuda.set_device(device) # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
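
Condensed to its effect, the DDP branch of compute_init does the following under a torchrun launch on a CUDA machine. This is a hedged sketch reusing names from the diff; the logging and precision setup around it are omitted:

import os
import torch
import torch.distributed as dist

ddp_local_rank = int(os.environ["LOCAL_RANK"])              # one process per GPU under torchrun
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device)                               # bare "cuda" now refers to this GPU
dist.init_process_group(backend="nccl", device_id=device)   # bind NCCL to the local device
# ... training ...
dist.destroy_process_group()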
@@ -173,11 +185,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
         if ddp_rank == 0:
             logger.info(f"Distributed world size: {ddp_world_size}")
-    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
+    return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
 
 def compute_cleanup():
     """Companion function to compute_init, to clean things up before script exit"""
-    if is_ddp():
+    if is_ddp_initialized():
         dist.destroy_process_group()
 
 class DummyWandb:
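
The cleanup guard matters because compute_cleanup can also be reached from a plain single-process `python` run, where destroy_process_group() fails for lack of a default group. A small sketch of the guarded pattern (hypothetical wrapper name, same checks as the diff):

import torch.distributed as dist

def safe_cleanup():
    # Only tear down the process group if one was actually created;
    # calling destroy_process_group() with no default group errors out.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

safe_cleanup()  # no-op in a single-process run, real teardown under DDP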