fix mfu statically keyed to h100 max tflops

Qubitium 2025-10-21 05:08:49 +00:00
parent 0f007889dd
commit ff0605c372
3 changed files with 108 additions and 7 deletions

nanochat/mfu.py (new file, +92)

@@ -0,0 +1,92 @@
import torch


def get_promised_flops_per_gpu():
    """
    Return best-effort dense BF16 Tensor/Matrix peak FLOPs for the active GPU.

    Returns:
        tuple[str, float, bool]: (device_name, flops_per_gpu, is_estimated)
        device_name is the CUDA-reported name for the active device.
        flops_per_gpu is the per-device BF16 dense peak in FLOPs.
        is_estimated is True when we fall back to heuristic defaults.
    """
    device_index = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device_index)
    name = props.name.lower()
    device_name = props.name

    def result(flops, estimated):
        return device_name, flops, estimated

    def has(*keywords: str) -> bool:
        return any(k in name for k in keywords)

    # --- NVIDIA Blackwell ---
    if has("gb200", "grace blackwell", "b200"):
        return result(2.25e15, False)  # B200 dense BF16 ≈ 2.25 PFLOPS (4.5 PFLOPS sparse)
    if has("b100"):
        return result(1.8e15, False)  # B100 dense BF16 ≈ 1.8 PFLOPS (3.5 PFLOPS sparse)
    # --- NVIDIA Hopper (H100/H200/H800) ---
    if has("h200"):
        # SXM ~= H100 SXM for compute; NVL/PCIe is lower clock
        if has("nvl", "pcie"):
            return result(836e12, False)  # H200 NVL/PCIe dense BF16 ≈ 836 TFLOPS
        return result(989e12, False)  # H200 SXM dense BF16 ≈ 989 TFLOPS
    if has("h100"):
        if has("nvl"):  # H100 NVL (per GPU)
            return result(835e12, False)  # 1671 TFLOPS (sparse) ⇒ ~835 dense
        if has("pcie"):
            return result(756e12, False)  # H100 PCIe dense BF16 ≈ 756 TFLOPS
        return result(989e12, False)  # H100 SXM dense BF16 ≈ 989 TFLOPS
    if has("h800"):
        # China-optimized Hopper; NVLink configs often quoted at H100-SXM-like numbers
        if has("nvl"):
            return result(989e12, False)  # H800 NVLink dense BF16 ≈ 989 TFLOPS (vendor configs)
        return result(756e12, False)  # H800 PCIe dense BF16 ≈ 756 TFLOPS
    # --- NVIDIA Ampere data center / export variants ---
    if has("a100", "pg506"):
        return result(312e12, False)  # A100 SXM dense BF16 = 312 TFLOPS
    if has("a800"):
        return result(312e12, False)  # A800 dense BF16 ≈ 312 TFLOPS (Ampere-class)
    # Useful Ada data-center cards
    if has("l40s", "l40-s", "l40 s"):
        return result(362e12, False)  # L40S dense BF16 ≈ 362 TFLOPS
    if has(" l4", " l4 ", "nvidia l4", " l4-") or name.endswith(" l4"):
        return result(121e12, False)  # L4 dense BF16 ≈ 121 TFLOPS
    # Other widely used Ampere data-center cards
    if has("a30"):
        return result(165e12, False)  # A30 dense BF16 ≈ 165 TFLOPS
    if has("a40"):
        return result(149.7e12, False)  # A40 dense BF16 ≈ 149.7 TFLOPS
    # --- AMD CDNA accelerators ---
    if has("mi355"):
        return result(2.5e15, False)  # MI355X dense BF16 ≈ 2.5 PFLOPS (5.0 PFLOPS sparse)
    if has("mi325"):
        return result(1.3074e15, False)  # MI325X dense BF16 ≈ 1.3074 PFLOPS
    if has("mi300x"):
        return result(1.3074e15, False)  # MI300X dense BF16 ≈ 1.3074 PFLOPS
    if has("mi300a"):
        return result(980.6e12, False)  # MI300A dense BF16 ≈ 980.6 TFLOPS
    if has("mi250x"):
        return result(383e12, False)  # MI250X dense BF16 ≈ 383 TFLOPS
    if has("mi250"):
        return result(362.1e12, False)  # MI250 dense BF16 ≈ 362.1 TFLOPS
    # --- Consumer RTX ---
    if has("5090"):
        return result(209.5e12, False)  # RTX 5090 dense BF16 ≈ 209.5 TFLOPS (w/ FP32 accumulate)
    if has("4090"):
        return result(165.2e12, False)  # RTX 4090 dense BF16 ≈ 165.2 TFLOPS (w/ FP32 accumulate)
    if has("3090"):
        return result(71e12, False)  # RTX 3090 dense BF16 ≈ 71 TFLOPS (w/ FP32 accumulate)
    # unknown
    return result(1e12, True)
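A minimal usage sketch of the new helper (the world size here is a made-up placeholder; the training scripts below wire it to ddp_world_size):

import torch
from nanochat.mfu import get_promised_flops_per_gpu

if torch.cuda.is_available():
    device_name, flops_per_gpu, is_estimated = get_promised_flops_per_gpu()
    ddp_world_size = 8  # hypothetical: an 8-GPU job
    promised_flops_per_sec = flops_per_gpu * ddp_world_size
    star = '*' if is_estimated else ''  # '*' marks the 1 TFLOPS unknown-device fallback
    print(f"{device_name}: {flops_per_gpu / 1e12:.1f} TFLOPS/GPU{star} | "
          f"job peak: {promised_flops_per_sec / 1e12:.1f} TFLOPS")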

scripts/base_train.py

@@ -21,6 +21,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
+from nanochat.mfu import get_promised_flops_per_gpu
from scripts.base_eval import evaluate_model
print_banner()
@@ -125,6 +126,10 @@ print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
+device_name, promised_flops_per_gpu, promised_is_estimated = get_promised_flops_per_gpu()
+promised_flops_per_sec = promised_flops_per_gpu * ddp_world_size
+print0(f"Detected GPU: {device_name} | peak BF16 TFLOPs (per GPU): {promised_flops_per_gpu / 1e12:.1f}{'*' if promised_is_estimated else ''}")
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
@@ -286,11 +291,10 @@ for step in range(num_iterations + 1):
pct_done = 100 * step / num_iterations
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
-promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
-mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
+mfu = 100 * flops_per_sec / promised_flops_per_sec # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f}{'*' if promised_is_estimated else ''} | total time: {total_training_time/60:.2f}m")
if step % 100 == 0:
wandb_run.log({
"step": step,
@@ -327,7 +331,7 @@ get_report().log(section="Base model training", data=[
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": val_bpb,
"CORE metric estimate": results["core_metric"],
"MFU %": f"{mfu:.2f}%",
"MFU %": f"{mfu:.2f}{'*' if promised_is_estimated else ''}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",

scripts/mid_train.py

@@ -21,6 +21,7 @@ from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.checkpoint_manager import load_model
+from nanochat.mfu import get_promised_flops_per_gpu
import torch.distributed as dist
from tasks.common import TaskMixture
@@ -75,6 +76,10 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
+device_name, promised_flops_per_gpu, promised_is_estimated = get_promised_flops_per_gpu()
+promised_flops_per_sec = promised_flops_per_gpu * ddp_world_size
+tflo_ps = promised_flops_per_gpu / 1e12
+print0(f"Detected GPU: {device_name} | peak BF16 TFLOPs (per GPU): {tflo_ps:.1f}{'*' if promised_is_estimated else ''}")
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
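As an aside, a short worked sketch of the batch bookkeeping printed in this hunk, under assumed hyperparameters (the real values come from the script's configuration):

# Assumed hyperparameters, for illustration only.
device_batch_size = 32
max_seq_len = 2048
ddp_world_size = 4
total_batch_size = 524288  # tokens per optimizer step

tokens_per_fwdbwd = device_batch_size * max_seq_len             # 32 * 2048 = 65,536 per rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size    # 262,144 per micro-batch
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd  # 524,288 // 262,144 = 2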
@@ -250,11 +255,11 @@ while True:
pct_done = 100 * progress
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
-promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
-mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
+mfu = 100 * flops_per_sec / promised_flops_per_sec # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
mfu_display = f"{mfu:.2f}{'*' if promised_is_estimated else ''}"
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu_display} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0:
wandb_run.log({
"step": step,