From cdb5a455eee47f79dd476c4a496b79ebefc2ddd3 Mon Sep 17 00:00:00 2001
From: Sermet Pekin <96650846+SermetPekin@users.noreply.github.com>
Date: Mon, 20 Oct 2025 11:43:47 +0300
Subject: [PATCH] fallback to cpu on compute_init function

fallback to cpu on compute_init function
---
 nanochat/common.py | 28 ++++++++++------------------
 1 file changed, 10 insertions(+), 18 deletions(-)

diff --git a/nanochat/common.py b/nanochat/common.py
index 8b10df9..3cbd6b0 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -91,34 +91,26 @@ def get_dist_info():
 
 def compute_init():
     """Basic initialization that we keep doing over and over, so make common."""
-
-    # CUDA is currently required
-    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
-
-    # Reproducibility
-    torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
-    # skipping full reproducibility for now, possibly investigate slowdown later
-    # torch.use_deterministic_algorithms(True)
-    # torch.backends.cudnn.deterministic = True
-    # torch.backends.cudnn.benchmark = False
-
+    # Check if CUDA is available, otherwise fall back to CPU
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        torch.manual_seed(42)
+        torch.cuda.manual_seed(42)
+    else:
+        device = torch.device("cpu")
+        torch.manual_seed(42)
+        logger.warning("CUDA is not available. Falling back to CPU.")
     # Precision
     torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
-
     # Distributed setup: Distributed Data Parallel (DDP), optional
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp:
+    if ddp and torch.cuda.is_available():
         device = torch.device("cuda", ddp_local_rank)
         torch.cuda.set_device(device) # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
-    else:
-        device = torch.device("cuda")
-
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")
-
     return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
 
 def compute_cleanup():
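
For reference, a minimal sketch of how a caller would exercise the patched compute_init() on a machine without CUDA. The call site mirrors the return signature visible in the diff and the import path follows nanochat/common.py; the print statements and the expected single-process values are illustrative assumptions, not part of the patch.

# Illustrative sketch only (not part of the patch): calling the patched
# compute_init() from a single-process, CPU-only script.
from nanochat.common import compute_init, compute_cleanup

# With this patch the old CUDA assert is gone: on a CUDA-less machine the
# function logs a warning and returns a CPU device instead of raising.
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()

print(device)               # expected: cpu (or cuda when a GPU is present)
print(ddp, ddp_world_size)  # expected for a single-process run: False 1

compute_cleanup()           # companion cleanup function from the same module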