def get_peak_flops(device=None) -> float:
    """
    Get peak FLOPS for the current GPU.

    Returns the peak BF16 Tensor Core FLOPS with FP32 accumulation (dense,
    i.e. without 2:4 sparsity) — the metric that corresponds to PyTorch
    bfloat16 training performance and is therefore the right denominator
    for MFU. Falls back to the H100 SXM figure (989e12) when the GPU is
    unknown or detection fails.

    Args:
        device: torch device to query. Defaults to cuda:0.

    Returns:
        Peak dense BF16 tensor FLOPS as a float (FLOP/s).
    """
    # Dense (non-sparse) BF16 tensor-core peaks, per vendor datasheets.
    GPU_PEAK_FLOPS = {
        "H200": 989e12,
        "B200": 2.25e15,
        "H100 NVL": 835e12,
        "H100 PCIe": 756e12,
        "H100": 989e12,
        "A100": 312e12,
        "6000 Blackwell Max-Q": 438.9e12,
        "6000 Blackwell": 503.8e12,
        "6000 Ada": 364e12,
        "A6000": 154.8e12,
        "RTX 5090": 209.5e12,
        "RTX 4090": 165.2e12,
        "RTX 3090": 71e12,
        "L40S": 362e12,
    }
    # Single source of truth for the fallback (H100 SXM, dense BF16).
    H100_FALLBACK_FLOPS = 989e12
    # Log from rank 0 only, so DDP runs don't emit one line per rank.
    is_rank0 = int(os.environ.get('RANK', 0)) == 0

    # Keep the try minimal: only the device query can legitimately fail
    # (no CUDA, bad device index); a bug in the matching below should not
    # be silently swallowed as "could not detect GPU".
    try:
        if device is None:
            device = torch.device("cuda:0")
        device_name = torch.cuda.get_device_name(device)
    except Exception as e:
        if is_rank0:
            logger.warning(f"Could not detect GPU, defaulting to H100: {e}")
        return H100_FALLBACK_FLOPS

    # Match GPU by substring (case-insensitive). Sort keys longest-first so
    # more specific names win, e.g. "6000 Blackwell Max-Q" is checked
    # before "6000 Blackwell".
    for gpu_key, flops in sorted(GPU_PEAK_FLOPS.items(), key=lambda kv: len(kv[0]), reverse=True):
        if gpu_key.lower() in device_name.lower():
            if is_rank0:
                logger.info(f"Detected GPU: {device_name} -> using {flops/1e12:.1f} TFLOPS (BF16 Tensor)")
            return flops

    # Unknown GPU: warn (rank 0 only) and default to H100.
    if is_rank0:
        logger.warning(f"Unknown GPU '{device_name}', defaulting to H100 peak FLOPS (989e12)")
    return H100_FALLBACK_FLOPS
Either it is given, or from target flops, or from target data:param ratio (in that order) assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0 @@ -286,8 +288,7 @@ for step in range(num_iterations + 1): pct_done = 100 * step / num_iterations tok_per_sec = int(world_tokens_per_fwdbwd / dt) flops_per_sec = num_flops_per_token * total_batch_size / dt - promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + mfu = 100 * flops_per_sec / promised_flops_per_sec # in % if step > 10: total_training_time += dt # only count the time after the first 10 steps print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 90ab954..0222e4b 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -16,7 +16,7 @@ import time import wandb import torch -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, get_peak_flops from nanochat.tokenizer import get_token_bytes from nanochat.checkpoint_manager import save_checkpoint from nanochat.loss_eval import evaluate_bpb @@ -76,6 +76,8 @@ print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tok print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") token_bytes = get_token_bytes(device=device) +peak_flops_per_gpu = get_peak_flops(device) +promised_flops_per_sec = peak_flops_per_gpu * ddp_world_size # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) optimizers = 
model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) @@ -250,8 +252,7 @@ while True: pct_done = 100 * progress tok_per_sec = int(world_tokens_per_fwdbwd / dt) flops_per_sec = num_flops_per_token * total_batch_size / dt - promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + mfu = 100 * flops_per_sec / promised_flops_per_sec # in % if step > 10: total_training_time += dt # only count the time after the first 10 steps print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")