From 3cb530cf97ea3ad814dc3eff92f608ffdceef01b Mon Sep 17 00:00:00 2001
From: Dipesh Babu
Date: Mon, 9 Feb 2026 20:46:31 -0500
Subject: [PATCH] fix RoPE cache overflow with kv-cache by growing rope buffers

---
 nanochat/gpt.py | 43 +++++++++++++++++++++++++++++++++++++++----
 1 file changed, 39 insertions(+), 4 deletions(-)

diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 208acd1..cb1fe7f 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -22,6 +22,8 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.optim import MuonAdamW, DistMuonAdamW
 
+from typing import Optional
+
 # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
 from nanochat.flash_attention import flash_attn
 
@@ -185,6 +187,37 @@ class GPT(nn.Module):
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
 
+    def _ensure_rope_cache(self, needed_seq_len: int):
+        """
+        Ensure rotary embedding cache (cos/sin) is long enough.
+        We grow the cache lazily to avoid evaluation crashes on long prompts.
+        """
+        # existing cache length
+        cur = self.cos.size(1)
+        if needed_seq_len <= cur:
+            return
+
+        # grow to next power-of-two for amortized behavior
+        new_len = 1
+        while new_len < needed_seq_len:
+            new_len *= 2
+
+        head_dim = self.config.n_embd // self.config.n_head
+        device = self.cos.device
+
+        cos, sin = self._precompute_rotary_embeddings(
+            seq_len=new_len,
+            head_dim=head_dim,
+            device=device,
+        )
+        # keep dtype consistent with existing buffers
+        cos = cos.to(dtype=self.cos.dtype)
+        sin = sin.to(dtype=self.sin.dtype)
+
+        # re-register buffers (safe overwrite)
+        self.register_buffer("cos", cos, persistent=False)
+        self.register_buffer("sin", sin, persistent=False)
+
     @torch.no_grad()
     def init_weights(self):
         """
@@ -388,12 +421,14 @@ class GPT(nn.Module):
 
     def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
         B, T = idx.size()
-        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
-        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
+        T0 = 0 if kv_cache is None else kv_cache.get_pos()
+
+        # Ensure cache covers absolute positions [T0, T0+T)
+        self._ensure_rope_cache(T0 + T)
+
         assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
         assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
-        # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
-        T0 = 0 if kv_cache is None else kv_cache.get_pos()
+
         cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
         # Forward the trunk of the Transformer
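
Illustration (not part of the patch): the standalone sketch below shows the failure mode this
change addresses and the growth rule it introduces. During kv-cached decoding the absolute
position T0 + T keeps advancing, so it can exceed the length the cos/sin buffers were built
for; the fix regrows them to the next power of two on demand. RopeCache here is a hypothetical
stand-in for the buffers on nanochat's GPT module, written only to make the resizing behavior
concrete; its names, shapes, and base value are assumptions, not the repo's exact code.

# sketch of the lazy RoPE-cache growth; RopeCache is illustrative, not nanochat's API
import torch
import torch.nn as nn

class RopeCache(nn.Module):
    def __init__(self, head_dim: int, seq_len: int, base: float = 10000.0):
        super().__init__()
        self.head_dim = head_dim
        self.base = base
        cos, sin = self._precompute(seq_len)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def _precompute(self, seq_len: int):
        # buffers of shape (1, seq_len, 1, head_dim/2), matching the comment in gpt.py
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim/2)
        return freqs.cos()[None, :, None, :], freqs.sin()[None, :, None, :]

    def ensure(self, needed_seq_len: int):
        # same rule as _ensure_rope_cache: grow to the next power of two >= needed length
        if needed_seq_len <= self.cos.size(1):
            return
        new_len = 1
        while new_len < needed_seq_len:
            new_len *= 2
        cos, sin = self._precompute(new_len)
        # overwrite the existing non-persistent buffers, keeping their dtype
        self.register_buffer("cos", cos.to(self.cos.dtype), persistent=False)
        self.register_buffer("sin", sin.to(self.sin.dtype), persistent=False)

cache = RopeCache(head_dim=64, seq_len=1024)
# kv-cached decoding: the absolute position T0 + T can exceed the initial 1024
T0, T = 1020, 16
cache.ensure(T0 + T)  # grows the buffers to 2048 instead of asserting
cos_sin = cache.cos[:, T0:T0 + T], cache.sin[:, T0:T0 + T]
print(cache.cos.size(1), cos_sin[0].shape)  # 2048 torch.Size([1, 16, 1, 32])

The power-of-two growth means repeated decode steps trigger only O(log n) recomputations of
the rotary tables rather than one per overflowing step.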