mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-28 14:24:41 +00:00
more GPU types from PR 147 thanks @Qubitium
This commit is contained in:
parent
2955650327
commit
f5425245f9
|
|
@ -201,51 +201,76 @@ class DummyWandb:
|
|||
def finish(self):
|
||||
pass
|
||||
|
||||
# hardcoded BF16 peak flops for NVIDIA A100, H100, H200, B200 GPU and AMD MI250, MI300X, MI325X, MI355X and Intel PVC
|
||||
# hardcoded BF16 peak flops for various GPUs
|
||||
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
|
||||
# and PR: https://github.com/karpathy/nanochat/pull/147
|
||||
def get_peak_flops(device_name: str) -> float:
|
||||
if "A100" in device_name:
|
||||
# data from https://www.nvidia.com/en-us/data-center/a100/
|
||||
return 312e12
|
||||
elif "H100" in device_name:
|
||||
# data from https://www.nvidia.com/en-us/data-center/h100/
|
||||
# NOTE: Specifications are one-half lower without sparsity.
|
||||
if "NVL" in device_name:
|
||||
return 835e12
|
||||
elif "PCIe" in device_name:
|
||||
return 756e12
|
||||
else: # for H100 SXM and other variants
|
||||
return 989e12
|
||||
elif "H200" in device_name:
|
||||
# data from https://www.nvidia.com/en-us/data-center/h200/
|
||||
return 989e12
|
||||
elif "B200" in device_name:
|
||||
# data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
|
||||
name = device_name.lower()
|
||||
|
||||
# --- NVIDIA Blackwell ---
|
||||
if "gb200" in name or "grace blackwell" in name:
|
||||
return 2.5e15
|
||||
if "b200" in name:
|
||||
return 2.25e15
|
||||
elif "MI355X" in device_name:
|
||||
# MI355X data from https://www.amd.com/en/products/accelerators/instinct/mi350/mi355x.html
|
||||
return 2500e12
|
||||
elif "MI300X" in device_name or "MI325X" in device_name:
|
||||
# MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
|
||||
# MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
|
||||
return 1300e12
|
||||
elif "MI250X" in device_name:
|
||||
# data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
|
||||
return 191.5e12
|
||||
elif "Data Center GPU Max 1550" in device_name:
|
||||
# Also known as Ponte Vecchio (PVC).
|
||||
# data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
|
||||
# Dot Product Accumulate Systolic (DPAS):
|
||||
# - Freq: 1300MHz
|
||||
# - #ops: 512
|
||||
# Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
|
||||
# Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
|
||||
if "b100" in name:
|
||||
return 1.8e15
|
||||
|
||||
# --- NVIDIA Hopper (H100/H200/H800) ---
|
||||
if "h200" in name:
|
||||
if "nvl" in name or "pcie" in name:
|
||||
return 836e12
|
||||
return 989e12 # H200 SXM
|
||||
if "h100" in name:
|
||||
if "nvl" in name:
|
||||
return 835e12
|
||||
if "pcie" in name:
|
||||
return 756e12
|
||||
return 989e12 # H100 SXM
|
||||
if "h800" in name:
|
||||
if "nvl" in name:
|
||||
return 989e12
|
||||
return 756e12 # H800 PCIe
|
||||
|
||||
# --- NVIDIA Ampere data center ---
|
||||
if "a100" in name or "a800" in name:
|
||||
return 312e12
|
||||
if "a40" in name:
|
||||
return 149.7e12
|
||||
if "a30" in name:
|
||||
return 165e12
|
||||
|
||||
# --- NVIDIA Ada data center ---
|
||||
if "l40s" in name or "l40-s" in name or "l40 s" in name:
|
||||
return 362e12
|
||||
if "l4" in name:
|
||||
return 121e12
|
||||
|
||||
# --- AMD CDNA accelerators ---
|
||||
if "mi355" in name:
|
||||
return 2.5e15
|
||||
if "mi325" in name or "mi300x" in name:
|
||||
return 1.3074e15
|
||||
if "mi300a" in name:
|
||||
return 980.6e12
|
||||
if "mi250x" in name:
|
||||
return 383e12
|
||||
if "mi250" in name:
|
||||
return 362.1e12
|
||||
|
||||
# --- Intel ---
|
||||
if "data center gpu max 1550" in name:
|
||||
# Ponte Vecchio (PVC) - dynamic based on compute units
|
||||
max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
|
||||
return 512 * max_comp_units * 1300 * 10**6
|
||||
elif "l40s" in device_name:
|
||||
# data from: "https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413"
|
||||
return 362e12
|
||||
|
||||
else: # for other GPU types, assume A100
|
||||
logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
|
||||
return 312e12
|
||||
# --- Consumer RTX (for hobbyists) ---
|
||||
if "5090" in name:
|
||||
return 209.5e12
|
||||
if "4090" in name:
|
||||
return 165.2e12
|
||||
if "3090" in name:
|
||||
return 71e12
|
||||
|
||||
# Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
|
||||
logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
|
||||
return float('inf')
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user