diff --git a/nanochat/common.py b/nanochat/common.py
index 22559ce..faf9144 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -200,3 +200,52 @@ class DummyWandb:
         pass
     def finish(self):
         pass
+
+# hardcoded BF16 peak FLOPS for NVIDIA A100, H100, H200, B200, L40S, AMD MI250X, MI300X, MI325X, MI355X, and Intel PVC GPUs
+# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
+def get_peak_flops(device_name: str) -> float:
+    if "A100" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/a100/
+        return 312e12
+    elif "H100" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/h100/
+        # NOTE: these are dense figures; the datasheet numbers with 2:4 sparsity are 2x higher.
+        if "NVL" in device_name:
+            return 835e12
+        elif "PCIe" in device_name:
+            return 756e12
+        else: # for H100 SXM and other variants
+            return 989e12
+    elif "H200" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/h200/
+        return 989e12
+    elif "B200" in device_name:
+        # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
+        return 2.25e15
+    elif "MI355X" in device_name:
+        # MI355X data from https://www.amd.com/en/products/accelerators/instinct/mi350/mi355x.html
+        return 2500e12
+    elif "MI300X" in device_name or "MI325X" in device_name:
+        # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
+        # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
+        return 1300e12
+    elif "MI250X" in device_name:
+        # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
+        return 191.5e12
+    elif "Data Center GPU Max 1550" in device_name:
+        # Also known as Ponte Vecchio (PVC).
+        # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
+        # Dot Product Accumulate Systolic (DPAS):
+        # - Freq: 1300MHz
+        # - #ops: 512
+        # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
+        # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
+        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
+        return 512 * max_comp_units * 1300 * 10**6
+    elif "L40S" in device_name:
+        # dense BF16; the datasheet (https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413) lists 362 TFLOPS with 2:4 sparsity
+        return 181e12
+
+    else: # for other GPU types, assume A100
+        logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
+        return 312e12
diff --git a/scripts/base_train.py b/scripts/base_train.py
index c61986e..e051f99 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -22,7 +22,7 @@ import torch
 
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -82,6 +82,12 @@ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc
 autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
 synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
+if device_type == "cuda":
+    gpu_device_name = torch.cuda.get_device_name(0)
+    gpu_peak_flops = get_peak_flops(gpu_device_name)
+    print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
+else:
+    gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
 
 # wandb logging init
 use_dummy_wandb = args.run == "dummy" or not master_process
@@ -395,8 +401,7 @@ while True:
     pct_done = 100 * step / num_iterations
     tok_per_sec = int(args.total_batch_size / dt)
     flops_per_sec = num_flops_per_token * args.total_batch_size / dt
-    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
-    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
+    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
     # Calculate ETA based on average time per step (excluding first 10 steps)
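
Note: the sketch below (not part of the diff) shows how the new pieces compose at
runtime. All values are hypothetical, assuming an 8x H100 SXM node and the rough
~6N FLOPs/token forward+backward estimate; base_train.py supplies its own
num_flops_per_token and batch settings.

    num_params = 1.9e9                       # hypothetical model size N
    num_flops_per_token = 6 * num_params     # ~6N fwd+bwd rule of thumb (assumption)
    total_batch_size = 524288                # hypothetical tokens per optimizer step
    dt = 2.0                                 # hypothetical seconds per step
    ddp_world_size = 8
    gpu_peak_flops = 989e12                  # get_peak_flops("NVIDIA H100") -> SXM dense BF16

    flops_per_sec = num_flops_per_token * total_batch_size / dt    # ~2.99e15 achieved
    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)  # ~37.8% of 7.91e15 peak

Because gpu_peak_flops is float('inf') off-CUDA, mfu cleanly evaluates to 0.0
there instead of raising or reporting a bogus H100-relative number.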