From 61b7eae7e040cfdf461455a9793e9747a6ad7f6b Mon Sep 17 00:00:00 2001 From: svlandeg Date: Sat, 31 Jan 2026 12:03:43 +0100 Subject: [PATCH] use _PEAK_FLOPS_TABLE instead of if-else structure --- nanochat/common.py | 100 +++++++++++++++++++-------------------------- 1 file changed, 41 insertions(+), 59 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index db9e317..9bcd5dd 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -207,70 +207,52 @@ class DummyWandb: def get_peak_flops(device_name: str) -> float: name = device_name.lower() - # --- NVIDIA Blackwell --- - if "gb200" in name or "grace blackwell" in name: - return 2.5e15 - if "b200" in name: - return 2.25e15 - if "b100" in name: - return 1.8e15 - - # --- NVIDIA Hopper (H100/H200/H800) --- - if "h200" in name: - if "nvl" in name or "pcie" in name: - return 836e12 - return 989e12 # H200 SXM - if "h100" in name: - if "nvl" in name: - return 835e12 - if "pcie" in name: - return 756e12 - return 989e12 # H100 SXM - if "h800" in name: - if "nvl" in name: - return 989e12 - return 756e12 # H800 PCIe - - # --- NVIDIA Ampere data center --- - if "a100" in name or "a800" in name: - return 312e12 - if "a40" in name: - return 149.7e12 - if "a30" in name: - return 165e12 - - # --- NVIDIA Ada data center --- - if "l40s" in name or "l40-s" in name or "l40 s" in name: - return 362e12 - if "l4" in name: - return 121e12 - - # --- AMD CDNA accelerators --- - if "mi355" in name: - return 2.5e15 - if "mi325" in name or "mi300x" in name: - return 1.3074e15 - if "mi300a" in name: - return 980.6e12 - if "mi250x" in name: - return 383e12 - if "mi250" in name: - return 362.1e12 - - # --- Intel --- + # Table order matters: more specific patterns first. + _PEAK_FLOPS_TABLE = ( + # NVIDIA Blackwell + (["gb200"], 2.5e15), + (["grace blackwell"], 2.5e15), + (["b200"], 2.25e15), + (["b100"], 1.8e15), + # NVIDIA Hopper + (["h200", "nvl"], 836e12), + (["h200", "pcie"], 836e12), + (["h200"], 989e12), + (["h100", "nvl"], 835e12), + (["h100", "pcie"], 756e12), + (["h100"], 989e12), + (["h800", "nvl"], 989e12), + (["h800"], 756e12), + # NVIDIA Ampere data center + (["a100"], 312e12), + (["a800"], 312e12), + (["a40"], 149.7e12), + (["a30"], 165e12), + # NVIDIA Ada data center + (["l40s"], 362e12), + (["l40-s"], 362e12), + (["l40 s"], 362e12), + (["l4"], 121e12), + # AMD CDNA accelerators + (["mi355"], 2.5e15), + (["mi325"], 1.3074e15), + (["mi300x"], 1.3074e15), + (["mi300a"], 980.6e12), + (["mi250x"], 383e12), + (["mi250"], 362.1e12), + # Consumer RTX + (["5090"], 209.5e12), + (["4090"], 165.2e12), + (["3090"], 71e12), + ) + for patterns, flops in _PEAK_FLOPS_TABLE: + if all(p in name for p in patterns): + return flops if "data center gpu max 1550" in name: # Ponte Vecchio (PVC) - dynamic based on compute units max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units return 512 * max_comp_units * 1300 * 10**6 - # --- Consumer RTX (for hobbyists) --- - if "5090" in name: - return 209.5e12 - if "4090" in name: - return 165.2e12 - if "3090" in name: - return 71e12 - # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%") return float('inf')