Add automatic GPU detection for accurate MFU calculation

Ozamatash 2025-10-17 19:49:21 +00:00
parent d6d86cbf4c
commit ebd9c6171a
3 changed files with 53 additions and 6 deletions
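
MFU (model FLOPS utilization) compares achieved training FLOPS against the hardware's advertised peak, so hardcoding the H100 figure skews the number on any other GPU. A minimal sketch of the arithmetic this commit changes, with purely illustrative values (not from any real run):

    # Illustrative values only
    num_flops_per_token = 3.5e9   # as reported by model.estimate_flops()
    total_batch_size = 524288     # tokens per optimizer step
    dt = 0.5                      # measured seconds per step
    flops_per_sec = num_flops_per_token * total_batch_size / dt  # ~3.7e15
    promised_flops_per_sec = 989e12 * 8  # before this commit: 8 GPUs, always assumed H100 SXM
    mfu = 100 * flops_per_sec / promised_flops_per_sec  # ~46.4%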

View File

@@ -126,6 +126,51 @@ def compute_cleanup():
     if is_ddp():
         dist.destroy_process_group()
+
+def get_peak_flops(device=None) -> float:
+    """
+    Return peak BF16 Tensor Core FLOPS with FP32 accumulation (non-sparse) for the current GPU.
+    This is the figure that corresponds to PyTorch bfloat16 training performance.
+    Defaults to H100 if the GPU is unknown.
+    """
+    GPU_PEAK_FLOPS = {
+        "H200": 989e12,
+        "B200": 2.25e15,
+        "H100 NVL": 835e12,
+        "H100 PCIe": 756e12,
+        "H100": 989e12,
+        "A100": 312e12,
+        "6000 Blackwell Max-Q": 438.9e12,
+        "6000 Blackwell": 503.8e12,
+        "6000 Ada": 364e12,
+        "A6000": 154.8e12,
+        "RTX 5090": 209.5e12,
+        "RTX 4090": 165.2e12,
+        "RTX 3090": 71e12,
+        "L40S": 362e12,
+    }
+    try:
+        if device is None:
+            device = torch.device("cuda:0")
+        device_name = torch.cuda.get_device_name(device)
+        # Match GPU by substring (case-insensitive). Check longer keys first so more
+        # specific names win, e.g. "6000 Blackwell Max-Q" before "6000 Blackwell".
+        for gpu_key, flops in sorted(GPU_PEAK_FLOPS.items(), key=lambda x: len(x[0]), reverse=True):
+            if gpu_key.lower() in device_name.lower():
+                if int(os.environ.get('RANK', 0)) == 0:
+                    logger.info(f"Detected GPU: {device_name} -> using {flops/1e12:.1f} TFLOPS (BF16 Tensor)")
+                return flops
+        # Unknown GPU: warn and default to H100
+        logger.warning(f"Unknown GPU '{device_name}', defaulting to H100 peak FLOPS (989e12)")
+        return 989e12
+    except Exception as e:
+        logger.warning(f"Could not detect GPU, defaulting to H100: {e}")
+        return 989e12
+
 class DummyWandb:
     """Useful if we wish to not use wandb but have all the same signatures"""
     def __init__(self):
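
A quick usage sketch of the new helper; the device name and log output here are illustrative:

    import torch
    from nanochat.common import get_peak_flops

    peak = get_peak_flops(torch.device("cuda:0"))  # device=None also falls back to cuda:0
    # On an A100, rank 0 logs e.g.:
    #   Detected GPU: NVIDIA A100-SXM4-80GB -> using 312.0 TFLOPS (BF16 Tensor)
    # and peak == 312e12; an unrecognized GPU warns and returns the H100 default of 989e12.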

View File

@@ -16,7 +16,7 @@ import torch
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_peak_flops
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -104,6 +104,8 @@ num_params = sum(p.numel() for p in model.parameters())
 print0(f"Number of parameters: {num_params:,}")
 num_flops_per_token = model.estimate_flops()
 print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
+peak_flops_per_gpu = get_peak_flops(device)
+promised_flops_per_sec = peak_flops_per_gpu * ddp_world_size
 # Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
 assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
@@ -286,8 +288,7 @@ for step in range(num_iterations + 1):
     pct_done = 100 * step / num_iterations
     tok_per_sec = int(world_tokens_per_fwdbwd / dt)
     flops_per_sec = num_flops_per_token * total_batch_size / dt
-    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
-    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
+    mfu = 100 * flops_per_sec / promised_flops_per_sec # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")

View File

@@ -16,7 +16,7 @@ import time
 import wandb
 import torch
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, get_peak_flops
 from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -76,6 +76,8 @@ print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tok
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
 print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
 token_bytes = get_token_bytes(device=device)
+peak_flops_per_gpu = get_peak_flops(device)
+promised_flops_per_sec = peak_flops_per_gpu * ddp_world_size
 # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
 optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
@@ -250,8 +252,7 @@ while True:
     pct_done = 100 * progress
     tok_per_sec = int(world_tokens_per_fwdbwd / dt)
     flops_per_sec = num_flops_per_token * total_batch_size / dt
-    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
-    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
+    mfu = 100 * flops_per_sec / promised_flops_per_sec # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")