FEAT: Allow CPU-only execution in compute_init

Modified compute_init in nanochat/common.py to allow for CPU-only execution by checking the NANOCHAT_DEVICE environment variable. This enhances flexibility and hackability, aligning with the project's goals of accessibility.
This commit is contained in:
SyedaAnshrahGillani 2025-10-22 18:07:30 +05:00
parent b70da6d907
commit 6641aeed1d

View File

@ -92,19 +92,24 @@ def get_dist_info():
def compute_init():
"""Basic initialization that we keep doing over and over, so make common."""
# CUDA is currently required
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
# Determine device
device_type = os.environ.get("NANOCHAT_DEVICE", "cuda")
if device_type == "cuda":
assert torch.cuda.is_available(), "CUDA is not available. Set NANOCHAT_DEVICE=cpu to run on CPU."
device = torch.device(device_type)
# Reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
if device_type == "cuda":
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# Precision
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()