From dcc70667449509c537ae9e1984872f65bddb4dcc Mon Sep 17 00:00:00 2001 From: Matt Van Horn Date: Mon, 9 Mar 2026 07:49:56 -0700 Subject: [PATCH 1/2] dynamically grow RoPE cache for long eval prompts When training with small --max-seq-len (e.g. 256), the static 10x RoPE cache (2560 tokens) is too small for few-shot evaluation prompts that can exceed this limit. Instead of asserting, grow the cache on demand. Co-Authored-By: Claude Opus 4.6 --- nanochat/gpt.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 04ee5c5..168ebb8 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -396,12 +396,15 @@ class GPT(nn.Module): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() - # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) - assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" - assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" - assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}" # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache T0 = 0 if kv_cache is None else kv_cache.get_pos() + # Dynamically grow rotary cache if needed (e.g. few-shot eval prompts exceed training seq_len * 10) + if T0 + T > self.cos.size(1): + self.rotary_seq_len = T0 + T + head_dim = self.config.n_embd // self.config.n_head + self.cos, self.sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) + assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" + assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}" cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length # Forward the trunk of the Transformer From d097b1c958fd376f69d2c267489cb3cec5dd935f Mon Sep 17 00:00:00 2001 From: Matt Van Horn Date: Mon, 9 Mar 2026 11:26:08 -0700 Subject: [PATCH 2/2] fix: add 2x buffer to RoPE cache growth and update stale comment Addresses review feedback: grow to max(needed, 2x current) instead of exact size, so the cache doesn't regrow on every call. Also updates the __init__ comment that said dynamic growth was future work. Co-Authored-By: Claude Opus 4.6 --- nanochat/gpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 168ebb8..fb726be 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -184,7 +184,7 @@ class GPT(nn.Module): # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. - # In the future we can dynamically grow the cache, for now it's fine. + # The cache grows dynamically in forward() with a 2x buffer if eval prompts exceed this. self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer? head_dim = config.n_embd // config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) @@ -400,7 +400,7 @@ class GPT(nn.Module): T0 = 0 if kv_cache is None else kv_cache.get_pos() # Dynamically grow rotary cache if needed (e.g. few-shot eval prompts exceed training seq_len * 10) if T0 + T > self.cos.size(1): - self.rotary_seq_len = T0 + T + self.rotary_seq_len = max(T0 + T, self.cos.size(1) * 2) # 2x buffer to avoid regrowing every call head_dim = self.config.n_embd // self.config.n_head self.cos, self.sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"