mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 21:25:21 +00:00
dynamically grow RoPE cache for long eval prompts
When training with small --max-seq-len (e.g. 256), the static 10x RoPE cache (2560 tokens) is too small for few-shot evaluation prompts that can exceed this limit. Instead of asserting, grow the cache on demand. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
1076f97059
commit
dcc7066744
|
|
@ -396,12 +396,15 @@ class GPT(nn.Module):
|
|||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
|
||||
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
# Dynamically grow rotary cache if needed (e.g. few-shot eval prompts exceed training seq_len * 10)
|
||||
if T0 + T > self.cos.size(1):
|
||||
self.rotary_seq_len = T0 + T
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
self.cos, self.sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user