mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-22 04:43:27 +00:00
fix RoPE cache overflow with kv-cache by growing rope buffers
This commit is contained in:
parent
2dffdc8cf6
commit
3cb530cf97
|
|
@ -22,6 +22,8 @@ import torch.nn.functional as F
|
|||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
||||
|
||||
from typing import Optional
|
||||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
|
|
@ -185,6 +187,37 @@ class GPT(nn.Module):
|
|||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
|
||||
def _ensure_rope_cache(self, needed_seq_len: int):
|
||||
"""
|
||||
Ensure rotary embedding cache (cos/sin) is long enough.
|
||||
We grow the cache lazily to avoid evaluation crashes on long prompts.
|
||||
"""
|
||||
# existing cache length
|
||||
cur = self.cos.size(1)
|
||||
if needed_seq_len <= cur:
|
||||
return
|
||||
|
||||
# grow to next power-of-two for amortized behavior
|
||||
new_len = 1
|
||||
while new_len < needed_seq_len:
|
||||
new_len *= 2
|
||||
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
device = self.cos.device
|
||||
|
||||
cos, sin = self._precompute_rotary_embeddings(
|
||||
seq_len=new_len,
|
||||
head_dim=head_dim,
|
||||
device=device,
|
||||
)
|
||||
# keep dtype consistent with existing buffers
|
||||
cos = cos.to(dtype=self.cos.dtype)
|
||||
sin = sin.to(dtype=self.sin.dtype)
|
||||
|
||||
# re-register buffers (safe overwrite)
|
||||
self.register_buffer("cos", cos, persistent=False)
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def init_weights(self):
|
||||
"""
|
||||
|
|
@ -388,12 +421,14 @@ class GPT(nn.Module):
|
|||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
|
||||
# Ensure cache covers absolute positions [T0, T0+T)
|
||||
self._ensure_rope_cache(T0 + T)
|
||||
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user