mirror of
https://github.com/karpathy/nanochat.git
synced 2026-02-06 02:29:53 +00:00
Merge 61b7eae7e0 into 1ddaad1c1c
This commit is contained in:
commit
cb8200a16b
|
|
@ -207,70 +207,52 @@ class DummyWandb:
|
|||
def get_peak_flops(device_name: str) -> float:
|
||||
name = device_name.lower()
|
||||
|
||||
# --- NVIDIA Blackwell ---
|
||||
if "gb200" in name or "grace blackwell" in name:
|
||||
return 2.5e15
|
||||
if "b200" in name:
|
||||
return 2.25e15
|
||||
if "b100" in name:
|
||||
return 1.8e15
|
||||
|
||||
# --- NVIDIA Hopper (H100/H200/H800) ---
|
||||
if "h200" in name:
|
||||
if "nvl" in name or "pcie" in name:
|
||||
return 836e12
|
||||
return 989e12 # H200 SXM
|
||||
if "h100" in name:
|
||||
if "nvl" in name:
|
||||
return 835e12
|
||||
if "pcie" in name:
|
||||
return 756e12
|
||||
return 989e12 # H100 SXM
|
||||
if "h800" in name:
|
||||
if "nvl" in name:
|
||||
return 989e12
|
||||
return 756e12 # H800 PCIe
|
||||
|
||||
# --- NVIDIA Ampere data center ---
|
||||
if "a100" in name or "a800" in name:
|
||||
return 312e12
|
||||
if "a40" in name:
|
||||
return 149.7e12
|
||||
if "a30" in name:
|
||||
return 165e12
|
||||
|
||||
# --- NVIDIA Ada data center ---
|
||||
if "l40s" in name or "l40-s" in name or "l40 s" in name:
|
||||
return 362e12
|
||||
if "l4" in name:
|
||||
return 121e12
|
||||
|
||||
# --- AMD CDNA accelerators ---
|
||||
if "mi355" in name:
|
||||
return 2.5e15
|
||||
if "mi325" in name or "mi300x" in name:
|
||||
return 1.3074e15
|
||||
if "mi300a" in name:
|
||||
return 980.6e12
|
||||
if "mi250x" in name:
|
||||
return 383e12
|
||||
if "mi250" in name:
|
||||
return 362.1e12
|
||||
|
||||
# --- Intel ---
|
||||
# Table order matters: more specific patterns first.
|
||||
_PEAK_FLOPS_TABLE = (
|
||||
# NVIDIA Blackwell
|
||||
(["gb200"], 2.5e15),
|
||||
(["grace blackwell"], 2.5e15),
|
||||
(["b200"], 2.25e15),
|
||||
(["b100"], 1.8e15),
|
||||
# NVIDIA Hopper
|
||||
(["h200", "nvl"], 836e12),
|
||||
(["h200", "pcie"], 836e12),
|
||||
(["h200"], 989e12),
|
||||
(["h100", "nvl"], 835e12),
|
||||
(["h100", "pcie"], 756e12),
|
||||
(["h100"], 989e12),
|
||||
(["h800", "nvl"], 989e12),
|
||||
(["h800"], 756e12),
|
||||
# NVIDIA Ampere data center
|
||||
(["a100"], 312e12),
|
||||
(["a800"], 312e12),
|
||||
(["a40"], 149.7e12),
|
||||
(["a30"], 165e12),
|
||||
# NVIDIA Ada data center
|
||||
(["l40s"], 362e12),
|
||||
(["l40-s"], 362e12),
|
||||
(["l40 s"], 362e12),
|
||||
(["l4"], 121e12),
|
||||
# AMD CDNA accelerators
|
||||
(["mi355"], 2.5e15),
|
||||
(["mi325"], 1.3074e15),
|
||||
(["mi300x"], 1.3074e15),
|
||||
(["mi300a"], 980.6e12),
|
||||
(["mi250x"], 383e12),
|
||||
(["mi250"], 362.1e12),
|
||||
# Consumer RTX
|
||||
(["5090"], 209.5e12),
|
||||
(["4090"], 165.2e12),
|
||||
(["3090"], 71e12),
|
||||
)
|
||||
for patterns, flops in _PEAK_FLOPS_TABLE:
|
||||
if all(p in name for p in patterns):
|
||||
return flops
|
||||
if "data center gpu max 1550" in name:
|
||||
# Ponte Vecchio (PVC) - dynamic based on compute units
|
||||
max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
|
||||
return 512 * max_comp_units * 1300 * 10**6
|
||||
|
||||
# --- Consumer RTX (for hobbyists) ---
|
||||
if "5090" in name:
|
||||
return 209.5e12
|
||||
if "4090" in name:
|
||||
return 165.2e12
|
||||
if "3090" in name:
|
||||
return 71e12
|
||||
|
||||
# Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
|
||||
logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
|
||||
return float('inf')
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user