From 2dc85662c3f0dbc960ff875b87a72d2dfed694d8 Mon Sep 17 00:00:00 2001 From: Dipesh Babu Date: Wed, 5 Nov 2025 21:22:35 -0500 Subject: [PATCH 1/4] fix: safe DDP cleanup (check initialized PG, not just env) --- nanochat/common.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index d4a9828..7efd60d 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -113,12 +113,24 @@ def print_banner(): """ print0(banner) -def is_ddp(): - # TODO is there a proper way - return int(os.environ.get('RANK', -1)) != -1 +def is_ddp_requested() -> bool: + """ + True if launched by torchrun (env present), even before init. + Used to decide whether we *should* initialize a PG. + """ + return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE")) + +def is_ddp_initialized() -> bool: + """ + True if torch.distributed is available and the process group is initialized. + Used at cleanup to avoid destroying a non-existent PG. + """ + return dist.is_available() and dist.is_initialized() def get_dist_info(): - if is_ddp(): + if is_ddp_requested(): + # We rely on torchrun's env to decide if we SHOULD init. + # (Initialization itself happens in compute init.) assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) ddp_rank = int(os.environ['RANK']) ddp_local_rank = int(os.environ['LOCAL_RANK']) @@ -159,8 +171,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA - ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() - if ddp and device_type == "cuda": + is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() + if is_ddp_requested and device_type == "cuda": device = torch.device("cuda", ddp_local_rank) torch.cuda.set_device(device) # make "cuda" default to this device dist.init_process_group(backend="nccl", device_id=device) @@ -171,11 +183,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps if ddp_rank == 0: logger.info(f"Distributed world size: {ddp_world_size}") - return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device + return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device def compute_cleanup(): """Companion function to compute_init, to clean things up before script exit""" - if is_ddp(): + if is_ddp_initialized(): dist.destroy_process_group() class DummyWandb: From beb34ac43cdc6d331111eeda0ec85f1a0ac20754 Mon Sep 17 00:00:00 2001 From: Dipesh Babu Date: Sat, 31 Jan 2026 19:18:48 -0500 Subject: [PATCH 2/4] fix: correct LR warmdown step range --- scripts/base_train.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index 7ed6330..cc858d4 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -238,15 +238,28 @@ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very fir # Learning rate scheduler def get_lr_multiplier(it): - warmup_iters = round(args.warmup_ratio * num_iterations) - warmdown_iters = round(args.warmdown_ratio * num_iterations) - if it < warmup_iters: + # Note: optimizer steps run for it in [0, num_iterations-1] + warmup_iters = int(round(args.warmup_ratio * num_iterations)) + warmdown_iters = int(round(args.warmdown_ratio * num_iterations)) + + # Warmup (avoid division by zero when warmup_iters == 0) + if warmup_iters > 0 and it < warmup_iters: return (it + 1) / warmup_iters - elif it <= num_iterations - warmdown_iters: - return 1.0 - else: - progress = (num_iterations - it) / warmdown_iters - return progress * 1.0 + (1 - progress) * args.final_lr_frac + + # Warmdown should cover the last `warmdown_iters` optimizer steps: + # it in [num_iterations - warmdown_iters, num_iterations - 1] + if warmdown_iters > 0: + warmdown_start = num_iterations - warmdown_iters + # Ensure warmdown doesn't start before warmup ends (prevents overlap weirdness) + warmdown_start = max(warmdown_start, warmup_iters) + + if it >= warmdown_start: + # progress: 1.0 at warmdown_start, 0.0 at last optimizer step (num_iterations - 1) + span = max(1, (num_iterations - 1) - warmdown_start) # denom >= 1 + progress = (num_iterations - 1 - it) / span + return progress * 1.0 + (1.0 - progress) * args.final_lr_frac + + return 1.0 # Momentum scheduler for Muon optimizer def get_muon_momentum(it): From 3675a44cd6acbbafab7211d09b51646582ef4ae9 Mon Sep 17 00:00:00 2001 From: Dipesh Babu Date: Mon, 9 Feb 2026 20:46:31 -0500 Subject: [PATCH 3/4] fix RoPE cache overflow with kv-cache by growing rope buffers --- nanochat/gpt.py | 43 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..cb1fe7f 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -22,6 +22,8 @@ import torch.nn.functional as F from nanochat.common import get_dist_info, print0 from nanochat.optim import MuonAdamW, DistMuonAdamW +from typing import Optional + # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere from nanochat.flash_attention import flash_attn @@ -185,6 +187,37 @@ class GPT(nn.Module): self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) + def _ensure_rope_cache(self, needed_seq_len: int): + """ + Ensure rotary embedding cache (cos/sin) is long enough. + We grow the cache lazily to avoid evaluation crashes on long prompts. + """ + # existing cache length + cur = self.cos.size(1) + if needed_seq_len <= cur: + return + + # grow to next power-of-two for amortized behavior + new_len = 1 + while new_len < needed_seq_len: + new_len *= 2 + + head_dim = self.config.n_embd // self.config.n_head + device = self.cos.device + + cos, sin = self._precompute_rotary_embeddings( + seq_len=new_len, + head_dim=head_dim, + device=device, + ) + # keep dtype consistent with existing buffers + cos = cos.to(dtype=self.cos.dtype) + sin = sin.to(dtype=self.sin.dtype) + + # re-register buffers (safe overwrite) + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) + @torch.no_grad() def init_weights(self): """ @@ -388,12 +421,14 @@ class GPT(nn.Module): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() - # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) - assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" + T0 = 0 if kv_cache is None else kv_cache.get_pos() + + # Ensure cache covers absolute positions [T0, T0+T) + self._ensure_rope_cache(T0 + T) + assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" - # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache - T0 = 0 if kv_cache is None else kv_cache.get_pos() + cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length # Forward the trunk of the Transformer From 1bf1fdaa0d484678dd409f2d56a5777824b7e060 Mon Sep 17 00:00:00 2001 From: Dipesh Babu Date: Mon, 9 Feb 2026 20:51:15 -0500 Subject: [PATCH 4/4] remove unused import --- nanochat/gpt.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index cb1fe7f..1409d8c 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -12,7 +12,6 @@ Notable features: - Flash Attention 3 integration """ -from functools import partial from dataclasses import dataclass import torch @@ -22,7 +21,6 @@ import torch.nn.functional as F from nanochat.common import get_dist_info, print0 from nanochat.optim import MuonAdamW, DistMuonAdamW -from typing import Optional # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere from nanochat.flash_attention import flash_attn