diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 1409d8c..0b717bb 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -185,36 +185,46 @@ class GPT(nn.Module): self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) - def _ensure_rope_cache(self, needed_seq_len: int): + def _ensure_rope_cache(self, needed_seq_len: int, device: torch.device): """ - Ensure rotary embedding cache (cos/sin) is long enough. - We grow the cache lazily to avoid evaluation crashes on long prompts. + Ensure rotary embedding cache (cos/sin) is long enough for absolute positions [0, needed_seq_len). + + We grow lazily to avoid crashes for long prompts / long KV-cache generation. + Growth is amortized by rounding up to the next power of two. + + NOTE: We avoid register_buffer() here; we simply overwrite the existing buffers. """ - # existing cache length - cur = self.cos.size(1) - if needed_seq_len <= cur: + cur_len = self.cos.size(1) + if needed_seq_len <= cur_len: return - # grow to next power-of-two for amortized behavior - new_len = 1 - while new_len < needed_seq_len: - new_len *= 2 + # Safety: mutating buffers during torch.compile tracing is unsafe. + try: + import torch._dynamo + if torch._dynamo.is_compiling(): + raise RuntimeError( + f"RoPE cache too small during torch.compile (need {needed_seq_len}, have {cur_len}). " + f"Increase initial rotary_seq_len or disable compile for generation." + ) + except Exception: + # torch._dynamo may not exist in older torch; ignore. + pass + + # Next power-of-two >= needed_seq_len + new_len = 1 << (needed_seq_len - 1).bit_length() head_dim = self.config.n_embd // self.config.n_head - device = self.cos.device - cos, sin = self._precompute_rotary_embeddings( - seq_len=new_len, - head_dim=head_dim, - device=device, - ) - # keep dtype consistent with existing buffers - cos = cos.to(dtype=self.cos.dtype) - sin = sin.to(dtype=self.sin.dtype) + seq_len=new_len, head_dim=head_dim, device=device) - # re-register buffers (safe overwrite) - self.register_buffer("cos", cos, persistent=False) - self.register_buffer("sin", sin, persistent=False) + # Preserve dtype/device invariants (precompute returns bf16 already) + cos = cos.to(dtype=self.cos.dtype, device=device) + sin = sin.to(dtype=self.sin.dtype, device=device) + + # Overwrite existing registered buffers (no re-register) + self.cos = cos + self.sin = sin + self.rotary_seq_len = new_len # keep metadata consistent @torch.no_grad() def init_weights(self): @@ -418,16 +428,16 @@ class GPT(nn.Module): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() - T0 = 0 if kv_cache is None else kv_cache.get_pos() - # Ensure cache covers absolute positions [T0, T0+T) - self._ensure_rope_cache(T0 + T) + # Ensure RoPE buffers cover absolute positions [T0, T0+T) + self._ensure_rope_cache(T0 + T, device=idx.device) - assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" - assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" - - cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length + # Now it's safe to slice + assert idx.device == self.cos.device, f"RoPE buffers and idx device mismatch: {idx.device} != {self.cos.device}" + assert self.cos.dtype == torch.bfloat16, "RoPE buffers must be bfloat16" + + cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # Forward the trunk of the Transformer x = self.transformer.wte(idx) # embed current token