From 322eb6b86b9410e8c4347d4e20fc36ca80bcddad Mon Sep 17 00:00:00 2001 From: ademeure Date: Thu, 9 Apr 2026 11:29:04 +0000 Subject: [PATCH] Add profiling infrastructure (env-var controlled, nsys/ncu/torch profiler) - base_train.py: CUDA profiler + PyTorch profiler hooks gated by NANOCHAT_PROFILE_* env vars - profile_step.py: standalone single-step profiler with NVTX ranges and phase selection - LOCAL_STATE.md: documents local branch/file state before machine teardown Co-Authored-By: Claude Opus 4.6 (1M context) --- LOCAL_STATE.md | 35 +++++++++++ scripts/base_train.py | 50 +++++++++++++++ scripts/profile_step.py | 135 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 220 insertions(+) create mode 100644 LOCAL_STATE.md create mode 100644 scripts/profile_step.py diff --git a/LOCAL_STATE.md b/LOCAL_STATE.md new file mode 100644 index 00000000..b27c7913 --- /dev/null +++ b/LOCAL_STATE.md @@ -0,0 +1,35 @@ +# Local State — nanochat (karpathy fork) + +Documented 2026-04-09 before machine teardown. + +## Branch: fa3-flex-sdpa (current) +- Tracking: `fork/fa3-flex-sdpa` (ademeure/nanochat) — pushed and up to date +- 1 commit ahead of upstream master: `3d0dec5 FA3/FlexAttention/SDPA attention + PyTorch 2.11/CUDA 13.0` + +## Branch: pytorch-2.11-cu130 +- Tracking: `fork/pytorch-2.11-cu130` — pushed and up to date +- 2 commits ahead of master + +## Branch: pytorch-2.11-cu128-test +- **Local-only, no upstream** — but 0 commits ahead of master, just a branch pointer. No unique content. 
+ +## Uncommitted changes (being committed now) + +### scripts/base_train.py +- Added env-var-controlled profiling hooks (`NANOCHAT_PROFILE_START`, `NANOCHAT_PROFILE_STOP`, `NANOCHAT_PROFILE_EXIT`, `NANOCHAT_TORCH_PROFILE_DIR`) +- CUDA profiler start/stop integration around training steps +- PyTorch profiler with tensorboard trace output +- Early exit after profiling completes +- This is a work-in-progress profiling integration — functional but may need further tuning + +### scripts/profile_step.py (new file) +- Standalone profiling script for a single training step (fwd/bwd/opt) +- Supports nsys and ncu profiling with NVTX ranges +- Usage: `nsys profile -o out python -m scripts.profile_step --depth 6` +- Supports `--phase {all,fwd,bwd,opt}` for targeted kernel analysis + +### profiles/ (NOT committed — binary nsys artifacts) +- `nsys_d32_full.nsys-rep` (1.6M) — nsys trace, depth=32 +- `nsys_d32_full.sqlite` (2.4M) — exported sqlite +- `nsys_d32_minimal.nsys-rep` (1.5M) — minimal nsys trace +- These are reproducible output artifacts, not committed to git diff --git a/scripts/base_train.py b/scripts/base_train.py index 20bbf6e5..e27a88b0 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -412,6 +412,34 @@ print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_l print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") +# Profiling hooks (env-var controlled, no-op by default) +_profile_start = int(os.environ.get("NANOCHAT_PROFILE_START", -1)) +_profile_stop = int(os.environ.get("NANOCHAT_PROFILE_STOP", -1)) +_profile_exit = int(os.environ.get("NANOCHAT_PROFILE_EXIT", -1)) +_torch_profile_dir = os.environ.get("NANOCHAT_TORCH_PROFILE_DIR", "") +if _profile_start >= 0: + print0(f"Profiling: start at step {_profile_start}, stop at step {_profile_stop}, exit at step {_profile_exit}") + +# PyTorch profiler (env-var controlled) 
+_torch_profiler = None +if _torch_profile_dir and _profile_start >= 0: + from torch.profiler import profile, ProfilerActivity, schedule, tensorboard_trace_handler + os.makedirs(_torch_profile_dir, exist_ok=True) + _torch_profiler = profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule( + wait=_profile_start, + warmup=0, + active=(_profile_stop - _profile_start + 1) if _profile_stop >= 0 else 1, + repeat=1, + ), + on_trace_ready=tensorboard_trace_handler(_torch_profile_dir), + record_shapes=True, + with_stack=True, + ) + _torch_profiler.start() + print0(f"PyTorch profiler: tracing steps {_profile_start}-{_profile_stop}, output to {_torch_profile_dir}") + # Go! while True: last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end @@ -504,6 +532,10 @@ while True: # ------------------------------------------------------------------------- # single training step + if step == _profile_start: + print0(f">>> CUDA profiler START at step {step}") + synchronize() + torch.cuda.cudart().cudaProfilerStart() # evaluate the gradient synchronize() t0 = time.time() @@ -579,10 +611,28 @@ while True: } wandb_run.log(log_data) + # profiling stop + if step == _profile_stop: + synchronize() + torch.cuda.cudart().cudaProfilerStop() + print0(f">>> CUDA profiler STOP after step {step}") + + # PyTorch profiler step + if _torch_profiler is not None: + _torch_profiler.step() + # state update first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step) step += 1 + # profiling early exit (checked after step increment) + if _profile_exit >= 0 and step > _profile_exit: + if _torch_profiler is not None: + _torch_profiler.stop() + print0(f">>> PyTorch profiler stopped, traces written to {_torch_profile_dir}") + print0(f">>> Early exit after step {_profile_exit} (profiling done)") + break + # The garbage collector is sadly a little bit overactive and for some poorly understood reason, # it spends 
~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time.
     # So we manually manage and help it out here
diff --git a/scripts/profile_step.py b/scripts/profile_step.py
new file mode 100644
index 00000000..e911768b
--- /dev/null
+++ b/scripts/profile_step.py
@@ -0,0 +1,135 @@
+"""
+Profile a single training step of nanochat (forward + backward + optimizer).
+Run it under nsys or ncu (see Usage); the script itself writes no report files.
+
+Usage:
+    # Nsight Systems (full timeline):
+    nsys profile -o profile_nsys_d6 python -m scripts.profile_step --depth 6
+    nsys profile -o profile_nsys_d24 python -m scripts.profile_step --depth 24
+
+    # NCU (kernel-level, split by phase to keep reports manageable):
+    ncu --set full -o profile_ncu_d6_fwd python -m scripts.profile_step --depth 6 --phase fwd
+    ncu --set full -o profile_ncu_d6_bwd python -m scripts.profile_step --depth 6 --phase bwd
+    ncu --set full -o profile_ncu_d6_opt python -m scripts.profile_step --depth 6 --phase opt
+"""
+import os
+os.environ.setdefault("NANOCHAT_BASE_DIR", os.path.expanduser("~/.cache/nanochat"))  # set before the nanochat imports; a caller-provided value wins
+
+import argparse
+import torch
+import torch.cuda.nvtx as nvtx
+
+from nanochat.common import COMPUTE_DTYPE, print0
+from nanochat.gpt import GPT, GPTConfig
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--depth", type=int, default=6)
+parser.add_argument("--phase", type=str, default="all", choices=["all", "fwd", "bwd", "opt"])
+parser.add_argument("--seq-len", type=int, default=1024)
+parser.add_argument("--batch-size", type=int, default=16)
+parser.add_argument("--head-dim", type=int, default=64)
+parser.add_argument("--aspect-ratio", type=int, default=48)
+args = parser.parse_args()
+
+# ---------------------------------------------------------------------------
+# Setup
+device = torch.device("cuda")
+torch.manual_seed(42)
+torch.set_float32_matmul_precision("high")
+
+# Build model (same logic as base_train.py)
+base_dim = args.depth * args.aspect_ratio
+model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim  # round up to a multiple of head_dim
+num_heads = model_dim // args.head_dim
+config = GPTConfig(
+    sequence_len=args.seq_len, vocab_size=32768,
+    n_layer=args.depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
+    window_pattern="SSSL",
+)
+with torch.device("meta"):  # construct on the meta device, then materialize on the GPU below
+    model = GPT(config)
+model.to_empty(device=device)
+model.init_weights()
+model = torch.compile(model, dynamic=False)
+model.train()
+
+optimizer = model.setup_optimizer(
+    unembedding_lr=0.01, embedding_lr=0.01, scalar_lr=0.01,
+    matrix_lr=0.01, weight_decay=0.1,
+)
+
+n_params = sum(p.numel() for p in model.parameters())
+print0(f"Model: depth={args.depth} dim={model_dim} heads={num_heads} params={n_params:,}")
+print0(f"Batch: {args.batch_size} x {args.seq_len} = {args.batch_size * args.seq_len:,} tokens")
+
+# Dummy data
+x = torch.randint(0, config.vocab_size, (args.batch_size, args.seq_len), device=device)
+y = torch.randint(0, config.vocab_size, (args.batch_size, args.seq_len), device=device)
+
+# ---------------------------------------------------------------------------
+# Warmup (let torch.compile JIT)
+print0("Warming up (torch.compile)...")
+for _ in range(3):
+    loss = model(x, y)
+    loss.backward()
+    optimizer.step()
+    model.zero_grad(set_to_none=True)
+torch.cuda.synchronize()
+print0("Warmup done. Profiling...")
+
+# ---------------------------------------------------------------------------
+# Profiled step: NVTX ranges label each phase; cudaProfilerStart/Stop bracket the profiled region
+
+def do_forward():
+    """Run the forward pass inside an NVTX "forward" range and return the loss."""
+    with nvtx.range("forward"):  # context manager pops the range even if the call raises
+        loss = model(x, y)
+        torch.cuda.synchronize()
+    return loss
+
+def do_backward(loss):
+    """Backpropagate `loss` inside an NVTX "backward" range."""
+    with nvtx.range("backward"):
+        loss.backward()
+        torch.cuda.synchronize()
+
+def do_optimizer():
+    """Apply one optimizer step inside an NVTX "optimizer" range, then clear the gradients."""
+    with nvtx.range("optimizer"):
+        optimizer.step()
+        torch.cuda.synchronize()
+    model.zero_grad(set_to_none=True)
+
+if args.phase == "fwd":
+    torch.cuda.cudart().cudaProfilerStart()
+    loss = do_forward()
+    torch.cuda.cudart().cudaProfilerStop()
+    print0(f"Forward done. loss={loss.item():.4f}")
+
+elif args.phase == "bwd":
+    loss = model(x, y)  # unprofiled forward
+    torch.cuda.synchronize()
+    torch.cuda.cudart().cudaProfilerStart()
+    do_backward(loss)
+    torch.cuda.cudart().cudaProfilerStop()
+    print0("Backward done.")
+
+elif args.phase == "opt":
+    loss = model(x, y)  # unprofiled forward+backward
+    loss.backward()
+    torch.cuda.synchronize()
+    torch.cuda.cudart().cudaProfilerStart()
+    do_optimizer()
+    torch.cuda.cudart().cudaProfilerStop()
+    print0("Optimizer done.")
+
+else:  # "all"
+    torch.cuda.cudart().cudaProfilerStart()
+    loss = do_forward()
+    do_backward(loss)
+    do_optimizer()
+    torch.cuda.cudart().cudaProfilerStop()
+    print0(f"Full step done. loss={loss.item():.4f}")
+
+peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
+print0(f"Peak VRAM: {peak_mb:.0f} MiB")