diff --git a/nanochat/gpt.py b/nanochat/gpt.py index a4e22efe..e7e19550 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -177,39 +177,60 @@ class GPT(nn.Module): self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)}) # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # Precompute a reasonably large RoPE cache up front (cheap relative to model weights). - # The cache is also allowed to grow dynamically in forward() if generation exceeds this length. + # The cache may also grow lazily in forward() if generation exceeds this length. self.rotary_seq_len = config.sequence_len * 10 + # Bound lazy growth to avoid unbounded memory usage during very long generation runs. + self.max_rotary_seq_len = max(self.rotary_seq_len, config.sequence_len * 64) + head_dim = config.n_embd // config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) - def _ensure_rope_cache(self, needed_seq_len: int, device: torch.device): + def _ensure_rope_cache(self, needed_seq_len: int): """ Ensure rotary embedding cache (cos/sin) is long enough for absolute positions [0, needed_seq_len). We grow lazily to avoid crashes for long prompts / long KV-cache generation. Growth is amortized by rounding up to the next power of two. + + Growth is bounded by self.max_rotary_seq_len to avoid unbounded memory usage. """ cur_len = self.cos.size(1) if needed_seq_len <= cur_len: return - # Next power-of-two >= needed_seq_len - new_len = 1 << (needed_seq_len - 1).bit_length() + if needed_seq_len > self.max_rotary_seq_len: + raise RuntimeError( + f"RoPE cache request exceeds max_rotary_seq_len: need {needed_seq_len}, " + f"have {cur_len}, cap {self.max_rotary_seq_len}. " + "Increase max_rotary_seq_len for longer-context generation." + ) + + # Safety: mutating buffers during torch.compile tracing is unsafe. + import torch._dynamo + if torch._dynamo.is_compiling(): + raise RuntimeError( + f"RoPE cache too small during torch.compile (need {needed_seq_len}, have {cur_len}). " + "Increase initial rotary_seq_len/max_rotary_seq_len or avoid compiled generation." + ) + + # Next power-of-two >= needed_seq_len (amortized growth), bounded by cap + new_len = min(self.max_rotary_seq_len, 1 << (needed_seq_len - 1).bit_length()) head_dim = self.config.n_embd // self.config.n_head + device = self.cos.device cos, sin = self._precompute_rotary_embeddings(seq_len=new_len, head_dim=head_dim, device=device) # Preserve dtype/device invariants (precompute already returns bf16, but keep explicit) cos = cos.to(dtype=self.cos.dtype, device=device) sin = sin.to(dtype=self.sin.dtype, device=device) - # Overwrite existing registered buffers (keep same names, persistent=False property remains) + # Overwrite existing registered buffers (persistent=False remains from initial registration) self.cos = cos self.sin = sin self.rotary_seq_len = new_len - + @torch.no_grad() def init_weights(self): """ @@ -415,7 +436,7 @@ class GPT(nn.Module): T0 = 0 if kv_cache is None else kv_cache.get_pos() # Ensure RoPE buffers cover absolute positions [T0, T0+T) - self._ensure_rope_cache(T0 + T, device=idx.device) + self._ensure_rope_cache(T0 + T) assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"