commit 3ad83ecbc0
Dipesh Babu, 2025-11-13 11:50:09 -05:00 (committed by GitHub)


@@ -113,12 +113,24 @@ def print_banner():
     """
     print0(banner)
 
-def is_ddp():
-    # TODO is there a proper way
-    return int(os.environ.get('RANK', -1)) != -1
+def is_ddp_requested() -> bool:
+    """
+    True if launched by torchrun (env present), even before init.
+    Used to decide whether we *should* initialize a PG.
+    """
+    return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
+
+def is_ddp_initialized() -> bool:
+    """
+    True if torch.distributed is available and the process group is initialized.
+    Used at cleanup to avoid destroying a non-existent PG.
+    """
+    return dist.is_available() and dist.is_initialized()
 
 def get_dist_info():
-    if is_ddp():
+    if is_ddp_requested():
+        # We rely on torchrun's env to decide if we SHOULD init.
+        # (Initialization itself happens in compute_init.)
         assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
         ddp_rank = int(os.environ['RANK'])
         ddp_local_rank = int(os.environ['LOCAL_RANK'])
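
The two predicates deliberately answer different questions, and there is a window where they disagree: immediately after a torchrun launch but before init_process_group() runs, is_ddp_requested() is already True while is_ddp_initialized() is still False. A minimal standalone sketch of that window (the filename and the gloo backend are illustrative, not from this commit):

import os
import torch.distributed as dist

def is_ddp_requested() -> bool:
    # torchrun exports these three variables before user code runs
    return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))

def is_ddp_initialized() -> bool:
    # True only once init_process_group() has run in this process
    return dist.is_available() and dist.is_initialized()

if __name__ == "__main__":
    # Under `torchrun --nproc_per_node=2 sketch.py`: prints "True False",
    # then "True True". Under plain `python sketch.py`: prints "False False".
    print(is_ddp_requested(), is_ddp_initialized())
    if is_ddp_requested():
        dist.init_process_group(backend="gloo")  # gloo: no GPU required
        print(is_ddp_requested(), is_ddp_initialized())
        dist.destroy_process_group()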
@@ -161,8 +173,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
     torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
     # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
-    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp and device_type == "cuda":
+    is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    if is_ddp_requested and device_type == "cuda":
         device = torch.device("cuda", ddp_local_rank)
         torch.cuda.set_device(device) # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
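
Condensed, the CUDA branch of this init path amounts to the following shape (a sketch assuming a recent PyTorch with the device_id= parameter of init_process_group; the real compute_init also handles cpu/mps and logging, and compute_init_sketch is a hypothetical name):

import os
import torch
import torch.distributed as dist

def compute_init_sketch(device_type: str = "cuda") -> torch.device:
    # Hypothetical condensed flow, not the full function from the diff.
    requested = all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
    if requested and device_type == "cuda":
        local_rank = int(os.environ["LOCAL_RANK"])
        device = torch.device("cuda", local_rank)
        torch.cuda.set_device(device)  # bare "cuda" now resolves to this GPU
        # Passing device_id pins the process group to this GPU up front
        # instead of letting NCCL infer it on the first collective.
        dist.init_process_group(backend="nccl", device_id=device)
    else:
        device = torch.device(device_type)  # single-process fallback
    return device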
@@ -173,11 +185,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")
-    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
+    return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
 
 def compute_cleanup():
     """Companion function to compute_init, to clean things up before script exit"""
-    if is_ddp():
+    if is_ddp_initialized():
         dist.destroy_process_group()
 
 class DummyWandb:
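
Switching the cleanup guard from the env-var check to is_ddp_initialized() matters because dist.destroy_process_group() errors out when no process group exists, e.g. when a torchrun-launched process fails before init or when the script runs single-process. A sketch of the guarded teardown and its typical pairing (the try/finally wrapper and sketch names are illustrative, not from this commit):

import torch.distributed as dist

def compute_cleanup_sketch() -> None:
    # Guard on actual PG state, not on env vars: a torchrun launch that
    # crashed before init_process_group() must not crash again here.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

# Typical pairing at script scope (illustrative):
#     device = compute_init_sketch()
#     try:
#         ...  # training / eval loop
#     finally:
#         compute_cleanup_sketch()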