remove try catch

This commit is contained in:
Ozamatash 2025-10-19 20:35:42 -05:00 committed by GitHub
parent ebd9c6171a
commit 5158f00a8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -150,26 +150,21 @@ def get_peak_flops(device=None) -> float:
"L40S": 362e12,
}
try:
if device is None:
device = torch.device("cuda:0")
device_name = torch.cuda.get_device_name(device)
# Match GPU by substring (case-insensitive). Sort by length to check specific names first
# e.g., "6000 Blackwell Max-Q" checked before "6000 Blackwell"
for gpu_key, flops in sorted(GPU_PEAK_FLOPS.items(), key=lambda x: len(x[0]), reverse=True):
if gpu_key.lower() in device_name.lower():
if int(os.environ.get('RANK', 0)) == 0:
logger.info(f"Detected GPU: {device_name} -> using {flops/1e12:.1f} TFLOPS (BF16 Tensor)")
return flops
# Unknown GPU: warn and default to H100
logger.warning(f"Unknown GPU '{device_name}', defaulting to H100 peak FLOPS (989e12)")
return 989e12
except Exception as e:
logger.warning(f"Could not detect GPU, defaulting to H100: {e}")
return 989e12
if device is None:
device = torch.device("cuda:0")
device_name = torch.cuda.get_device_name(device)
# Match GPU by substring (case-insensitive). Sort by length to check specific names first
# e.g., "6000 Blackwell Max-Q" checked before "6000 Blackwell"
for gpu_key, flops in sorted(GPU_PEAK_FLOPS.items(), key=lambda x: len(x[0]), reverse=True):
if gpu_key.lower() in device_name.lower():
if int(os.environ.get('RANK', 0)) == 0:
logger.info(f"Detected GPU: {device_name} -> using {flops/1e12:.1f} TFLOPS (BF16 Tensor)")
return flops
# Unknown GPU: warn and default to H100
logger.warning(f"Unknown GPU '{device_name}', defaulting to H100 peak FLOPS (989e12)")
return 989e12
class DummyWandb:
"""Useful if we wish to not use wandb but have all the same signatures"""