Add automatic GPU detection for accurate MFU calculation
This commit is contained in:
parent
d6d86cbf4c
commit
ebd9c6171a
@@ -126,6 +126,51 @@ def compute_cleanup():
     if is_ddp():
         dist.destroy_process_group()
 
+def get_peak_flops(device=None) -> float:
+    """
+    Get peak FLOPS for the GPU.
+    Returns peak BF16 Tensor FLOPS with FP32 accumulation (non-sparse) for the current GPU.
+    This is the metric that corresponds to PyTorch bfloat16 training performance.
+    Defaults to H100 if GPU is unknown.
+    """
+    GPU_PEAK_FLOPS = {
+        "H200": 989e12,
+        "B200": 2.25e15,
+        "H100 NVL": 835e12,
+        "H100 PCIe": 756e12,
+        "H100": 989e12,
+        "A100": 312e12,
+        "6000 Blackwell Max-Q": 438.9e12,
+        "6000 Blackwell": 503.8e12,
+        "6000 Ada": 364e12,
+        "A6000": 154.8e12,
+        "RTX 5090": 209.5e12,
+        "RTX 4090": 165.2e12,
+        "RTX 3090": 71e12,
+        "L40S": 362e12,
+    }
+
+    try:
+        if device is None:
+            device = torch.device("cuda:0")
+        device_name = torch.cuda.get_device_name(device)
+
+        # Match GPU by substring (case-insensitive). Sort by length to check specific names first,
+        # e.g., "6000 Blackwell Max-Q" is checked before "6000 Blackwell"
+        for gpu_key, flops in sorted(GPU_PEAK_FLOPS.items(), key=lambda x: len(x[0]), reverse=True):
+            if gpu_key.lower() in device_name.lower():
+                if int(os.environ.get('RANK', 0)) == 0:
+                    logger.info(f"Detected GPU: {device_name} -> using {flops/1e12:.1f} TFLOPS (BF16 Tensor)")
+                return flops
+
+        # Unknown GPU: warn and default to H100
+        logger.warning(f"Unknown GPU '{device_name}', defaulting to H100 peak FLOPS (989e12)")
+        return 989e12
+
+    except Exception as e:
+        logger.warning(f"Could not detect GPU, defaulting to H100: {e}")
+        return 989e12
+
 class DummyWandb:
     """Useful if we wish to not use wandb but have all the same signatures"""
     def __init__(self):
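For intuition, here is a minimal standalone sketch of the longest-match-first lookup introduced above. It is illustrative only: the PEAK table is a trimmed copy and the device-name strings are hypothetical, not part of the commit.

# Minimal sketch of the substring matching in get_peak_flops (hypothetical inputs).
PEAK = {
    "6000 Blackwell Max-Q": 438.9e12,
    "6000 Blackwell": 503.8e12,
    "H100": 989e12,
}

def match_peak(device_name: str) -> float:
    # Longest keys are tried first, so "6000 Blackwell Max-Q" wins over "6000 Blackwell".
    for key, flops in sorted(PEAK.items(), key=lambda kv: len(kv[0]), reverse=True):
        if key.lower() in device_name.lower():
            return flops
    return 989e12  # unknown GPU: fall back to H100

assert match_peak("NVIDIA RTX PRO 6000 Blackwell Max-Q") == 438.9e12
assert match_peak("NVIDIA RTX PRO 6000 Blackwell Workstation Edition") == 503.8e12
assert match_peak("NVIDIA H100 80GB HBM3") == 989e12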
@@ -16,7 +16,7 @@ import torch
 
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_peak_flops
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -104,6 +104,8 @@ num_params = sum(p.numel() for p in model.parameters())
 print0(f"Number of parameters: {num_params:,}")
 num_flops_per_token = model.estimate_flops()
 print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
+peak_flops_per_gpu = get_peak_flops(device)
+promised_flops_per_sec = peak_flops_per_gpu * ddp_world_size
 
 # Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
 assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
@@ -286,8 +288,7 @@ for step in range(num_iterations + 1):
     pct_done = 100 * step / num_iterations
     tok_per_sec = int(world_tokens_per_fwdbwd / dt)
     flops_per_sec = num_flops_per_token * total_batch_size / dt
-    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
-    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
+    mfu = 100 * flops_per_sec / promised_flops_per_sec # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
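To make the new MFU denominator concrete, a worked example with assumed numbers (the per-token FLOPs, batch size, and step time below are hypothetical, not taken from this commit):

# Worked MFU example; all numbers are assumed, for illustration only.
peak_flops_per_gpu = 989e12                                   # detected H100 SXM, BF16 dense
ddp_world_size = 8
promised_flops_per_sec = peak_flops_per_gpu * ddp_world_size  # 7.912e15

num_flops_per_token = 3.5e9        # hypothetical model.estimate_flops() result
total_batch_size = 524288          # tokens per optimizer step (assumed)
dt = 0.5                           # measured seconds per step (assumed)

flops_per_sec = num_flops_per_token * total_batch_size / dt   # ~3.67e15
mfu = 100 * flops_per_sec / promised_flops_per_sec            # ~46.4 (%)
print(f"mfu: {mfu:.2f}")

On non-H100 hardware the only quantity that changes is peak_flops_per_gpu, which is exactly what the old hard-coded 989e12 got wrong.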
@@ -16,7 +16,7 @@ import time
 import wandb
 import torch
 
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, get_peak_flops
 from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -76,6 +76,8 @@ print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tok
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
 print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
 token_bytes = get_token_bytes(device=device)
+peak_flops_per_gpu = get_peak_flops(device)
+promised_flops_per_sec = peak_flops_per_gpu * ddp_world_size
 
 # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
 optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
@@ -250,8 +252,7 @@ while True:
     pct_done = 100 * progress
     tok_per_sec = int(world_tokens_per_fwdbwd / dt)
     flops_per_sec = num_flops_per_token * total_batch_size / dt
-    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
-    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
+    mfu = 100 * flops_per_sec / promised_flops_per_sec # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
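A quick way to sanity-check the detection on your own machine (a hypothetical interactive session; printed values depend on your hardware):

# Hypothetical sanity check; output depends on the local GPU.
import torch
from nanochat.common import get_peak_flops

print(torch.cuda.get_device_name(0))  # e.g. "NVIDIA GeForce RTX 4090"
print(get_peak_flops() / 1e12)        # e.g. 165.2 (BF16 Tensor TFLOPS, dense)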