diff --git a/nanochat/common.py b/nanochat/common.py
index 8f36f94..6294aa7 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -113,12 +113,24 @@ def print_banner():
     """
     print0(banner)
 
-def is_ddp():
-    # TODO is there a proper way
-    return int(os.environ.get('RANK', -1)) != -1
+def is_ddp_requested() -> bool:
+    """
+    True if launched by torchrun (env present), even before init.
+    Used to decide whether we *should* initialize a PG.
+    """
+    return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
+
+def is_ddp_initialized() -> bool:
+    """
+    True if torch.distributed is available and the process group is initialized.
+    Used at cleanup to avoid destroying a non-existent PG.
+    """
+    return dist.is_available() and dist.is_initialized()
 
 def get_dist_info():
-    if is_ddp():
+    if is_ddp_requested():
+        # We rely on torchrun's env to decide if we SHOULD init.
+        # (Initialization itself happens in compute_init.)
         assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
         ddp_rank = int(os.environ['RANK'])
         ddp_local_rank = int(os.environ['LOCAL_RANK'])
@@ -161,8 +173,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
     torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
 
     # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
-    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp and device_type == "cuda":
+    is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    if is_ddp_requested and device_type == "cuda":
         device = torch.device("cuda", ddp_local_rank)
         torch.cuda.set_device(device) # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
@@ -173,11 +185,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")
 
-    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
+    return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
 
 def compute_cleanup():
     """Companion function to compute_init, to clean things up before script exit"""
-    if is_ddp():
+    if is_ddp_initialized():
         dist.destroy_process_group()
 
 class DummyWandb:
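
Aside: the reason for splitting the old is_ddp() into two predicates is visible in compute_init above. Under torchrun with device_type != "cuda", the env vars are present but dist.init_process_group() is never called, so the old RANK-based check would make compute_cleanup() call dist.destroy_process_group() on a process group that does not exist. A minimal sketch of the distinction, with the torchrun env faked in-process (the env values are illustrative):

# Sketch: "requested" (launcher env present) vs "initialized" (PG actually created).
import os
import torch.distributed as dist

# Simulate what torchrun exports; real values come from the launcher.
os.environ.update({"RANK": "0", "LOCAL_RANK": "0", "WORLD_SIZE": "1"})

def is_ddp_requested() -> bool:
    # torchrun launched us, whether or not a process group exists yet.
    return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))

def is_ddp_initialized() -> bool:
    # Only true after dist.init_process_group() has actually run.
    return dist.is_available() and dist.is_initialized()

print(is_ddp_requested())    # True: the launcher env is present
print(is_ddp_initialized())  # False: no process group was created
# Guarding cleanup on is_ddp_initialized() is what keeps
# dist.destroy_process_group() from being called with nothing to destroy.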
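
And a sketch of how a caller consumes the pair; compute_init/compute_cleanup are from this diff, while the train.py entry point and launch commands are assumptions for illustration:

# Hypothetical entry point (train.py is not part of this diff).
# Single process:  python train.py
# Distributed:     torchrun --nproc_per_node=2 train.py
import torch
from nanochat.common import compute_init, compute_cleanup

def main():
    ddp, rank, local_rank, world_size, device = compute_init(device_type="cuda")
    try:
        x = torch.randn(4, 4, device=device)  # stand-in for real work
        if rank == 0:
            print(f"world_size={world_size}, device={device}")
    finally:
        # Safe in both launch modes: compute_cleanup() only destroys the
        # process group if one was actually initialized.
        compute_cleanup()

if __name__ == "__main__":
    main()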