diff --git a/nanochat/mfu.py b/nanochat/mfu.py new file mode 100644 index 0000000..157e68d --- /dev/null +++ b/nanochat/mfu.py @@ -0,0 +1,91 @@ +import torch + + +def get_promised_flops_per_gpu(): + """ + Return best-effort dense BF16 Tensor/Matrix peak FLOPs for the active GPU. + Returns: + tuple[str, float, bool]: (device_name, flops_per_gpu, is_estimated) + device_name is the CUDA-reported name for the active device. + flops_per_gpu is the per-device BF16 dense peak in FLOPs. + is_estimated is True when we fall back to heuristic defaults. + """ + device_index = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device_index) + name = props.name.lower() + device_name = props.name + + def result(flops, estimated): + return device_name, flops, estimated + + def has(*keywords: str) -> bool: + return any(k in name for k in keywords) + + # --- NVIDIA Blackwell --- + if has("gb200", "grace blackwell"): + return result(2.5e15, False) # GB200 dense BF16 ≈ 2.5 PFLOPS (5.0 PFLOPS sparse) + if has("b200"): + return result(2.25e15, False) # B200 dense BF16 ≈ 2.25 PFLOPS (4.5 PFLOPS sparse) + if has("b100"): + return result(1.8e15, False) # B100 dense BF16 ≈ 1.8 PFLOPS (3.5 PFLOPS sparse) + + # --- NVIDIA Hopper (H100/H200/H800) --- + if has("h200"): + if has("nvl", "pcie"): + return result(836e12, False) # H200 NVL/PCIe dense BF16 ≈ 836 TFLOPS + return result(989e12, False) # H200 SXM dense BF16 ≈ 989 TFLOPS + + if has("h100"): + if has("nvl"): + return result(835e12, False) # H100 NVL dense BF16 ≈ 835 TFLOPS + if has("pcie"): + return result(756e12, False) # H100 PCIe dense BF16 ≈ 756 TFLOPS + return result(989e12, False) # H100 SXM dense BF16 ≈ 989 TFLOPS + + if has("h800"): + if has("nvl"): + return result(989e12, False) # H800 NVLink dense BF16 ≈ 989 TFLOPS + return result(756e12, False) # H800 PCIe dense BF16 ≈ 756 TFLOPS + + # --- NVIDIA Ampere data center / export variants --- + if has("a100", "pg506"): + return result(312e12, False) # A100 SXM dense BF16 = 312 TFLOPS + if has("a800"): + return result(312e12, False) # A800 dense BF16 ≈ 312 TFLOPS (Ampere-class) + + # Useful Ada data-center cards + if has("l40s", "l40-s", "l40 s"): + return result(362e12, False) # L40S dense BF16 ≈ 362 TFLOPS + if has(" l4", " l4 ", "nvidia l4", " l4-") or name.endswith(" l4"): + return result(121e12, False) # L4 dense BF16 ≈ 121 TFLOPS + + # Other widely used Ampere data-center cards + if has("a30"): + return result(165e12, False) # A30 dense BF16 ≈ 165 TFLOPS + if has("a40"): + return result(149.7e12, False) # A40 dense BF16 ≈ 149.7 TFLOPS + + # --- AMD CDNA accelerators --- + if has("mi355"): + return result(2.5e15, False) # MI355X dense BF16 ≈ 2.5 PFLOPS (5.0 PFLOPS sparse) + if has("mi325"): + return result(1.3074e15, False) # MI325X dense BF16 ≈ 1.3074 PFLOPS + if has("mi300x"): + return result(1.3074e15, False) # MI300X dense BF16 ≈ 1.3074 PFLOPS + if has("mi300a"): + return result(980.6e12, False) # MI300A dense BF16 ≈ 980.6 TFLOPS + if has("mi250x"): + return result(383e12, False) # MI250X dense BF16 ≈ 383 TFLOPS + if has("mi250"): + return result(362.1e12, False) # MI250 dense BF16 ≈ 362.1 TFLOPS + + # --- Consumer RTX --- + if has("5090"): + return result(209.5e12, False) # RTX 5090 dense BF16 ≈ 209.5 TFLOPS (w/ FP32 accumulate) + if has("4090"): + return result(165.2e12, False) # RTX 4090 dense BF16 ≈ 165.2 TFLOPS (w/ FP32 accumulate) + if has("3090"): + return result(71e12, False) # RTX 3090 dense BF16 ≈ 71 TFLOPS (w/ FP32 accumulate) + + # unknown + return result(1e12, True) diff --git a/scripts/base_train.py b/scripts/base_train.py index ddd2c98..5d486d3 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -26,6 +26,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.checkpoint_manager import save_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine +from nanochat.mfu import get_promised_flops_per_gpu from scripts.base_eval import evaluate_model print_banner() @@ -138,6 +139,10 @@ print0(f"Total number of training tokens: {total_tokens:,}") print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") +device_name, promised_flops_per_gpu, promised_is_estimated = get_promised_flops_per_gpu() +promised_flops_per_sec = promised_flops_per_gpu * ddp_world_size +print0(f"Detected GPU: {device_name} | peak BF16 TFLOPs (per GPU): {promised_flops_per_gpu / 1e12:.1f}{'*' if promised_is_estimated else ''}") + # ----------------------------------------------------------------------------- # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) @@ -296,11 +301,10 @@ for step in range(num_iterations + 1): pct_done = 100 * step / num_iterations tok_per_sec = int(total_batch_size / dt) flops_per_sec = num_flops_per_token * total_batch_size / dt - promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + mfu = 100 * flops_per_sec / promised_flops_per_sec # in % if step > 10: total_training_time += dt # only count the time after the first 10 steps - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f}{'*' if promised_is_estimated else ''} | total time: {total_training_time/60:.2f}m") if step % 100 == 0: wandb_run.log({ "step": step, @@ -337,7 +341,7 @@ get_report().log(section="Base model training", data=[ "Minimum validation bpb": min_val_bpb, "Final validation bpb": val_bpb, "CORE metric estimate": results.get("core_metric", None), - "MFU %": f"{mfu:.2f}%", + "MFU %": f"{mfu:.2f}{'*' if promised_is_estimated else ''}%", "Total training flops": f"{flops_so_far:e}", "Total training time": f"{total_training_time/60:.2f}m", "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB", diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 6c2b82f..f4ae249 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -21,6 +21,7 @@ from nanochat.tokenizer import get_token_bytes from nanochat.checkpoint_manager import save_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.checkpoint_manager import load_model +from nanochat.mfu import get_promised_flops_per_gpu import torch.distributed as dist from tasks.common import TaskMixture @@ -81,6 +82,10 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") +device_name, promised_flops_per_gpu, promised_is_estimated = get_promised_flops_per_gpu() +promised_flops_per_sec = promised_flops_per_gpu * ddp_world_size +tflo_ps = promised_flops_per_gpu / 1e12 +print0(f"Detected GPU: {device_name} | peak BF16 TFLOPs (per GPU): {tflo_ps:.1f}{'*' if promised_is_estimated else ''}") token_bytes = get_token_bytes(device=device) # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) @@ -270,11 +275,11 @@ while True: pct_done = 100 * progress tok_per_sec = int(total_batch_size / dt) flops_per_sec = num_flops_per_token * total_batch_size / dt - promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + mfu = 100 * flops_per_sec / promised_flops_per_sec # in % if step > 10: total_training_time += dt # only count the time after the first 10 steps - print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") + mfu_display = f"{mfu:.2f}{'*' if promised_is_estimated else ''}" + print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu_display} | total time: {total_training_time/60:.2f}m") if step % 10 == 0: wandb_run.log({ "step": step,