From 3cb530cf97ea3ad814dc3eff92f608ffdceef01b Mon Sep 17 00:00:00 2001 From: Dipesh Babu Date: Mon, 9 Feb 2026 20:46:31 -0500 Subject: [PATCH 1/6] fix RoPE cache overflow with kv-cache by growing rope buffers --- nanochat/gpt.py | 43 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..cb1fe7f 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -22,6 +22,8 @@ import torch.nn.functional as F from nanochat.common import get_dist_info, print0 from nanochat.optim import MuonAdamW, DistMuonAdamW +from typing import Optional + # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere from nanochat.flash_attention import flash_attn @@ -185,6 +187,37 @@ class GPT(nn.Module): self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) + def _ensure_rope_cache(self, needed_seq_len: int): + """ + Ensure rotary embedding cache (cos/sin) is long enough. + We grow the cache lazily to avoid evaluation crashes on long prompts. + """ + # existing cache length + cur = self.cos.size(1) + if needed_seq_len <= cur: + return + + # grow to next power-of-two for amortized behavior + new_len = 1 + while new_len < needed_seq_len: + new_len *= 2 + + head_dim = self.config.n_embd // self.config.n_head + device = self.cos.device + + cos, sin = self._precompute_rotary_embeddings( + seq_len=new_len, + head_dim=head_dim, + device=device, + ) + # keep dtype consistent with existing buffers + cos = cos.to(dtype=self.cos.dtype) + sin = sin.to(dtype=self.sin.dtype) + + # re-register buffers (safe overwrite) + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) + @torch.no_grad() def init_weights(self): """ @@ -388,12 +421,14 @@ class GPT(nn.Module): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() - # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) - assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" + T0 = 0 if kv_cache is None else kv_cache.get_pos() + + # Ensure cache covers absolute positions [T0, T0+T) + self._ensure_rope_cache(T0 + T) + assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" - # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache - T0 = 0 if kv_cache is None else kv_cache.get_pos() + cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length # Forward the trunk of the Transformer From 5c5cff25a5c176b64f15321d765958d6d7fc5241 Mon Sep 17 00:00:00 2001 From: Dipesh Babu Date: Mon, 9 Feb 2026 20:51:15 -0500 Subject: [PATCH 2/6] remove unused import --- nanochat/gpt.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index cb1fe7f..1409d8c 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -12,7 +12,6 @@ Notable features: - Flash Attention 3 integration """ -from functools import partial from dataclasses import dataclass import torch @@ -22,7 +21,6 @@ import torch.nn.functional as F from nanochat.common import get_dist_info, print0 from nanochat.optim import MuonAdamW, DistMuonAdamW -from typing import Optional # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere from nanochat.flash_attention import flash_attn From 7e42702894903ed4f684468c6ce18e65bd9dd85b Mon Sep 17 00:00:00 2001 From: Dipesh Babu Date: Fri, 20 Feb 2026 17:29:37 -0500 Subject: [PATCH 3/6] fix: grow RoPE cache for KV-cache inference --- nanochat/gpt.py | 68 ++++++++++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 1409d8c..0b717bb 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -185,36 +185,46 @@ class GPT(nn.Module): self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) - def _ensure_rope_cache(self, needed_seq_len: int): + def _ensure_rope_cache(self, needed_seq_len: int, device: torch.device): """ - Ensure rotary embedding cache (cos/sin) is long enough. - We grow the cache lazily to avoid evaluation crashes on long prompts. + Ensure rotary embedding cache (cos/sin) is long enough for absolute positions [0, needed_seq_len). + + We grow lazily to avoid crashes for long prompts / long KV-cache generation. + Growth is amortized by rounding up to the next power of two. + + NOTE: We avoid register_buffer() here; we simply overwrite the existing buffers. """ - # existing cache length - cur = self.cos.size(1) - if needed_seq_len <= cur: + cur_len = self.cos.size(1) + if needed_seq_len <= cur_len: return - # grow to next power-of-two for amortized behavior - new_len = 1 - while new_len < needed_seq_len: - new_len *= 2 + # Safety: mutating buffers during torch.compile tracing is unsafe. + try: + import torch._dynamo + if torch._dynamo.is_compiling(): + raise RuntimeError( + f"RoPE cache too small during torch.compile (need {needed_seq_len}, have {cur_len}). " + f"Increase initial rotary_seq_len or disable compile for generation." + ) + except Exception: + # torch._dynamo may not exist in older torch; ignore. + pass + + # Next power-of-two >= needed_seq_len + new_len = 1 << (needed_seq_len - 1).bit_length() head_dim = self.config.n_embd // self.config.n_head - device = self.cos.device - cos, sin = self._precompute_rotary_embeddings( - seq_len=new_len, - head_dim=head_dim, - device=device, - ) - # keep dtype consistent with existing buffers - cos = cos.to(dtype=self.cos.dtype) - sin = sin.to(dtype=self.sin.dtype) + seq_len=new_len, head_dim=head_dim, device=device) - # re-register buffers (safe overwrite) - self.register_buffer("cos", cos, persistent=False) - self.register_buffer("sin", sin, persistent=False) + # Preserve dtype/device invariants (precompute returns bf16 already) + cos = cos.to(dtype=self.cos.dtype, device=device) + sin = sin.to(dtype=self.sin.dtype, device=device) + + # Overwrite existing registered buffers (no re-register) + self.cos = cos + self.sin = sin + self.rotary_seq_len = new_len # keep metadata consistent @torch.no_grad() def init_weights(self): @@ -418,16 +428,16 @@ class GPT(nn.Module): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() - T0 = 0 if kv_cache is None else kv_cache.get_pos() - # Ensure cache covers absolute positions [T0, T0+T) - self._ensure_rope_cache(T0 + T) + # Ensure RoPE buffers cover absolute positions [T0, T0+T) + self._ensure_rope_cache(T0 + T, device=idx.device) - assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" - assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" - - cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length + # Now it's safe to slice + assert idx.device == self.cos.device, f"RoPE buffers and idx device mismatch: {idx.device} != {self.cos.device}" + assert self.cos.dtype == torch.bfloat16, "RoPE buffers must be bfloat16" + + cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # Forward the trunk of the Transformer x = self.transformer.wte(idx) # embed current token From c546a44001bdb63d8ed790934f39a2b199b858ba Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 24 Feb 2026 16:51:34 +0100 Subject: [PATCH 4/6] restore original assert messages --- nanochat/gpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 0b717bb..3e12424 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -434,8 +434,8 @@ class GPT(nn.Module): self._ensure_rope_cache(T0 + T, device=idx.device) # Now it's safe to slice - assert idx.device == self.cos.device, f"RoPE buffers and idx device mismatch: {idx.device} != {self.cos.device}" - assert self.cos.dtype == torch.bfloat16, "RoPE buffers must be bfloat16" + assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" + assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] From 2b55abe918628349f36dab5519c6351f2150a2e5 Mon Sep 17 00:00:00 2001 From: Dipesh Babu <59379458+dipeshbabu@users.noreply.github.com> Date: Tue, 24 Feb 2026 16:48:37 -0500 Subject: [PATCH 5/6] Refactor rotary embedding cache management --- nanochat/gpt.py | 48 ++++++++++++++++-------------------------------- 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 3e12424..a4e22ef 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -176,10 +176,9 @@ class GPT(nn.Module): kv_dim = config.n_kv_head * head_dim self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)}) # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. - # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, - # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. - # In the future we can dynamically grow the cache, for now it's fine. - self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer? + # Precompute a reasonably large RoPE cache up front (cheap relative to model weights). + # The cache is also allowed to grow dynamically in forward() if generation exceeds this length. + self.rotary_seq_len = config.sequence_len * 10 head_dim = config.n_embd // config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint @@ -188,43 +187,28 @@ class GPT(nn.Module): def _ensure_rope_cache(self, needed_seq_len: int, device: torch.device): """ Ensure rotary embedding cache (cos/sin) is long enough for absolute positions [0, needed_seq_len). - + We grow lazily to avoid crashes for long prompts / long KV-cache generation. Growth is amortized by rounding up to the next power of two. - - NOTE: We avoid register_buffer() here; we simply overwrite the existing buffers. """ cur_len = self.cos.size(1) if needed_seq_len <= cur_len: return - - # Safety: mutating buffers during torch.compile tracing is unsafe. - try: - import torch._dynamo - if torch._dynamo.is_compiling(): - raise RuntimeError( - f"RoPE cache too small during torch.compile (need {needed_seq_len}, have {cur_len}). " - f"Increase initial rotary_seq_len or disable compile for generation." - ) - except Exception: - # torch._dynamo may not exist in older torch; ignore. - pass - + # Next power-of-two >= needed_seq_len new_len = 1 << (needed_seq_len - 1).bit_length() - + head_dim = self.config.n_embd // self.config.n_head - cos, sin = self._precompute_rotary_embeddings( - seq_len=new_len, head_dim=head_dim, device=device) - - # Preserve dtype/device invariants (precompute returns bf16 already) + cos, sin = self._precompute_rotary_embeddings(seq_len=new_len, head_dim=head_dim, device=device) + + # Preserve dtype/device invariants (precompute already returns bf16, but keep explicit) cos = cos.to(dtype=self.cos.dtype, device=device) sin = sin.to(dtype=self.sin.dtype, device=device) - - # Overwrite existing registered buffers (no re-register) + + # Overwrite existing registered buffers (keep same names, persistent=False property remains) self.cos = cos self.sin = sin - self.rotary_seq_len = new_len # keep metadata consistent + self.rotary_seq_len = new_len @torch.no_grad() def init_weights(self): @@ -429,14 +413,14 @@ class GPT(nn.Module): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() T0 = 0 if kv_cache is None else kv_cache.get_pos() - + # Ensure RoPE buffers cover absolute positions [T0, T0+T) self._ensure_rope_cache(T0 + T, device=idx.device) - - # Now it's safe to slice + assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" - + + # If kv cache exists, offset RoPE by current absolute position cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # Forward the trunk of the Transformer From b661d41ffd750ce67328e5bcd6cc626acfbb8497 Mon Sep 17 00:00:00 2001 From: Dipesh Babu <59379458+dipeshbabu@users.noreply.github.com> Date: Tue, 24 Feb 2026 17:38:02 -0500 Subject: [PATCH 6/6] Enhance rotary embedding cache management Refactor rotary embedding cache handling to improve memory management and error handling. --- nanochat/gpt.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index a4e22ef..e7e1955 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -177,39 +177,60 @@ class GPT(nn.Module): self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)}) # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # Precompute a reasonably large RoPE cache up front (cheap relative to model weights). - # The cache is also allowed to grow dynamically in forward() if generation exceeds this length. + # The cache may also grow lazily in forward() if generation exceeds this length. self.rotary_seq_len = config.sequence_len * 10 + # Bound lazy growth to avoid unbounded memory usage during very long generation runs. + self.max_rotary_seq_len = max(self.rotary_seq_len, config.sequence_len * 64) + head_dim = config.n_embd // config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) - def _ensure_rope_cache(self, needed_seq_len: int, device: torch.device): + def _ensure_rope_cache(self, needed_seq_len: int): """ Ensure rotary embedding cache (cos/sin) is long enough for absolute positions [0, needed_seq_len). We grow lazily to avoid crashes for long prompts / long KV-cache generation. Growth is amortized by rounding up to the next power of two. + + Growth is bounded by self.max_rotary_seq_len to avoid unbounded memory usage. """ cur_len = self.cos.size(1) if needed_seq_len <= cur_len: return - # Next power-of-two >= needed_seq_len - new_len = 1 << (needed_seq_len - 1).bit_length() + if needed_seq_len > self.max_rotary_seq_len: + raise RuntimeError( + f"RoPE cache request exceeds max_rotary_seq_len: need {needed_seq_len}, " + f"have {cur_len}, cap {self.max_rotary_seq_len}. " + "Increase max_rotary_seq_len for longer-context generation." + ) + + # Safety: mutating buffers during torch.compile tracing is unsafe. + import torch._dynamo + if torch._dynamo.is_compiling(): + raise RuntimeError( + f"RoPE cache too small during torch.compile (need {needed_seq_len}, have {cur_len}). " + "Increase initial rotary_seq_len/max_rotary_seq_len or avoid compiled generation." + ) + + # Next power-of-two >= needed_seq_len (amortized growth), bounded by cap + new_len = min(self.max_rotary_seq_len, 1 << (needed_seq_len - 1).bit_length()) head_dim = self.config.n_embd // self.config.n_head + device = self.cos.device cos, sin = self._precompute_rotary_embeddings(seq_len=new_len, head_dim=head_dim, device=device) # Preserve dtype/device invariants (precompute already returns bf16, but keep explicit) cos = cos.to(dtype=self.cos.dtype, device=device) sin = sin.to(dtype=self.sin.dtype, device=device) - # Overwrite existing registered buffers (keep same names, persistent=False property remains) + # Overwrite existing registered buffers (persistent=False remains from initial registration) self.cos = cos self.sin = sin self.rotary_seq_len = new_len - + @torch.no_grad() def init_weights(self): """ @@ -415,7 +436,7 @@ class GPT(nn.Module): T0 = 0 if kv_cache is None else kv_cache.get_pos() # Ensure RoPE buffers cover absolute positions [T0, T0+T) - self._ensure_rope_cache(T0 + T, device=idx.device) + self._ensure_rope_cache(T0 + T) assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"