diff --git a/nanochat/common.py b/nanochat/common.py
index ff92f9a..6445203 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -150,26 +150,21 @@ def get_peak_flops(device=None) -> float:
         "L40S": 362e12,
     }
-    try:
-        if device is None:
-            device = torch.device("cuda:0")
-        device_name = torch.cuda.get_device_name(device)
-
-        # Match GPU by substring (case-insensitive). Sort by length to check specific names first
-        # e.g., "6000 Blackwell Max-Q" checked before "6000 Blackwell"
-        for gpu_key, flops in sorted(GPU_PEAK_FLOPS.items(), key=lambda x: len(x[0]), reverse=True):
-            if gpu_key.lower() in device_name.lower():
-                if int(os.environ.get('RANK', 0)) == 0:
-                    logger.info(f"Detected GPU: {device_name} -> using {flops/1e12:.1f} TFLOPS (BF16 Tensor)")
-                return flops
-
-        # Unknown GPU: warn and default to H100
-        logger.warning(f"Unknown GPU '{device_name}', defaulting to H100 peak FLOPS (989e12)")
-        return 989e12
-
-    except Exception as e:
-        logger.warning(f"Could not detect GPU, defaulting to H100: {e}")
-        return 989e12
+    if device is None:
+        device = torch.device("cuda:0")
+    device_name = torch.cuda.get_device_name(device)
+
+    # Match GPU by substring (case-insensitive). Sort by length to check specific names first
+    # e.g., "6000 Blackwell Max-Q" checked before "6000 Blackwell"
+    for gpu_key, flops in sorted(GPU_PEAK_FLOPS.items(), key=lambda x: len(x[0]), reverse=True):
+        if gpu_key.lower() in device_name.lower():
+            if int(os.environ.get('RANK', 0)) == 0:
+                logger.info(f"Detected GPU: {device_name} -> using {flops/1e12:.1f} TFLOPS (BF16 Tensor)")
+            return flops
+
+    # Unknown GPU: warn and default to H100
+    logger.warning(f"Unknown GPU '{device_name}', defaulting to H100 peak FLOPS (989e12)")
+    return 989e12
 
 
 class DummyWandb:
     """Useful if we wish to not use wandb but have all the same signatures"""