This commit is contained in:
Matt Van Horn 2026-03-26 21:18:16 +01:00 committed by GitHub
commit 1dd1c0660e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -191,7 +191,7 @@ class GPT(nn.Module):
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
# In the future we can dynamically grow the cache, for now it's fine.
# The cache grows dynamically in forward() with a 2x buffer if eval prompts exceed this.
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
@ -411,12 +411,15 @@ class GPT(nn.Module):
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
# Dynamically grow rotary cache if needed (e.g. few-shot eval prompts exceed training seq_len * 10)
if T0 + T > self.cos.size(1):
self.rotary_seq_len = max(T0 + T, self.cos.size(1) * 2) # 2x buffer to avoid regrowing every call
head_dim = self.config.n_embd // self.config.n_head
self.cos, self.sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Embed the tokens