Merge pull request #23 from LokiMetaSmith/fix-amd-triton-reinstall

Explicitly enable allow_tf32 in nanochat/common.py
This commit is contained in:
Lawrence R Kincheloe III 2025-11-23 02:28:00 -06:00 committed by GitHub
commit e14d7ba6bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -166,6 +166,9 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
# Precision
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Explicitly enable allow_tf32 to ensure it's on, helping silence warnings on some platforms
torch.backends.cuda.matmul.allow_tf32 = True
# print0(f"Precision set: float32_matmul_precision=high, allow_tf32={torch.backends.cuda.matmul.allow_tf32}")
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()