mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
fix: add 2x buffer to RoPE cache growth and update stale comment
Addresses review feedback: grow to max(needed, 2x current) instead of exact size, so the cache doesn't regrow on every call. Also updates the __init__ comment that said dynamic growth was future work. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
dcc7066744
commit
d097b1c958
|
|
@ -184,7 +184,7 @@ class GPT(nn.Module):
|
|||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
# In the future we can dynamically grow the cache, for now it's fine.
|
||||
# The cache grows dynamically in forward() with a 2x buffer if eval prompts exceed this.
|
||||
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
|
||||
head_dim = config.n_embd // config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
|
|
@ -400,7 +400,7 @@ class GPT(nn.Module):
|
|||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
# Dynamically grow rotary cache if needed (e.g. few-shot eval prompts exceed training seq_len * 10)
|
||||
if T0 + T > self.cos.size(1):
|
||||
self.rotary_seq_len = T0 + T
|
||||
self.rotary_seq_len = max(T0 + T, self.cos.size(1) * 2) # 2x buffer to avoid regrowing every call
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
self.cos, self.sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user