fix: grow RoPE cache for KV-cache inference

This commit is contained in:
Dipesh Babu 2026-02-20 17:29:37 -05:00
parent 5c5cff25a5
commit 7e42702894

View File

@@ -185,36 +185,46 @@ class GPT(nn.Module):
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
def _ensure_rope_cache(self, needed_seq_len: int):
def _ensure_rope_cache(self, needed_seq_len: int, device: torch.device):
"""
Ensure rotary embedding cache (cos/sin) is long enough.
We grow the cache lazily to avoid evaluation crashes on long prompts.
Ensure rotary embedding cache (cos/sin) is long enough for absolute positions [0, needed_seq_len).
We grow lazily to avoid crashes for long prompts / long KV-cache generation.
Growth is amortized by rounding up to the next power of two.
NOTE: We avoid register_buffer() here; we simply overwrite the existing buffers.
"""
# existing cache length
cur = self.cos.size(1)
if needed_seq_len <= cur:
cur_len = self.cos.size(1)
if needed_seq_len <= cur_len:
return
# grow to next power-of-two for amortized behavior
new_len = 1
while new_len < needed_seq_len:
new_len *= 2
# Safety: mutating buffers during torch.compile tracing is unsafe.
try:
import torch._dynamo
if torch._dynamo.is_compiling():
raise RuntimeError(
f"RoPE cache too small during torch.compile (need {needed_seq_len}, have {cur_len}). "
f"Increase initial rotary_seq_len or disable compile for generation."
)
except Exception:
# torch._dynamo may not exist in older torch; ignore.
pass
# Next power-of-two >= needed_seq_len
new_len = 1 << (needed_seq_len - 1).bit_length()
head_dim = self.config.n_embd // self.config.n_head
device = self.cos.device
cos, sin = self._precompute_rotary_embeddings(
seq_len=new_len,
head_dim=head_dim,
device=device,
)
# keep dtype consistent with existing buffers
cos = cos.to(dtype=self.cos.dtype)
sin = sin.to(dtype=self.sin.dtype)
seq_len=new_len, head_dim=head_dim, device=device)
# re-register buffers (safe overwrite)
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
# Preserve dtype/device invariants (precompute returns bf16 already)
cos = cos.to(dtype=self.cos.dtype, device=device)
sin = sin.to(dtype=self.sin.dtype, device=device)
# Overwrite existing registered buffers (no re-register)
self.cos = cos
self.sin = sin
self.rotary_seq_len = new_len # keep metadata consistent
@torch.no_grad()
def init_weights(self):
@@ -418,16 +428,16 @@ class GPT(nn.Module):
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()
T0 = 0 if kv_cache is None else kv_cache.get_pos()
# Ensure cache covers absolute positions [T0, T0+T)
self._ensure_rope_cache(T0 + T)
# Ensure RoPE buffers cover absolute positions [T0, T0+T)
self._ensure_rope_cache(T0 + T, device=idx.device)
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Now it's safe to slice
assert idx.device == self.cos.device, f"RoPE buffers and idx device mismatch: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "RoPE buffers must be bfloat16"
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
# Forward the trunk of the Transformer
x = self.transformer.wte(idx) # embed current token