mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-03 22:25:27 +00:00
Merge b661d41ffd into c7ba252142
This commit is contained in:
commit
848248a07d
|
|
@ -12,7 +12,6 @@ Notable features:
|
|||
- Flash Attention 3 integration
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
|
@ -22,6 +21,7 @@ import torch.nn.functional as F
|
|||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
||||
|
||||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
|
|
@ -176,15 +176,61 @@ class GPT(nn.Module):
|
|||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
# In the future we can dynamically grow the cache, for now it's fine.
|
||||
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
|
||||
# Precompute a reasonably large RoPE cache up front (cheap relative to model weights).
|
||||
# The cache may also grow lazily in forward() if generation exceeds this length.
|
||||
self.rotary_seq_len = config.sequence_len * 10
|
||||
# Bound lazy growth to avoid unbounded memory usage during very long generation runs.
|
||||
self.max_rotary_seq_len = max(self.rotary_seq_len, config.sequence_len * 64)
|
||||
|
||||
head_dim = config.n_embd // config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
|
||||
def _ensure_rope_cache(self, needed_seq_len: int):
|
||||
"""
|
||||
Ensure rotary embedding cache (cos/sin) is long enough for absolute positions [0, needed_seq_len).
|
||||
|
||||
We grow lazily to avoid crashes for long prompts / long KV-cache generation.
|
||||
Growth is amortized by rounding up to the next power of two.
|
||||
|
||||
Growth is bounded by self.max_rotary_seq_len to avoid unbounded memory usage.
|
||||
"""
|
||||
cur_len = self.cos.size(1)
|
||||
if needed_seq_len <= cur_len:
|
||||
return
|
||||
|
||||
if needed_seq_len > self.max_rotary_seq_len:
|
||||
raise RuntimeError(
|
||||
f"RoPE cache request exceeds max_rotary_seq_len: need {needed_seq_len}, "
|
||||
f"have {cur_len}, cap {self.max_rotary_seq_len}. "
|
||||
"Increase max_rotary_seq_len for longer-context generation."
|
||||
)
|
||||
|
||||
# Safety: mutating buffers during torch.compile tracing is unsafe.
|
||||
import torch._dynamo
|
||||
if torch._dynamo.is_compiling():
|
||||
raise RuntimeError(
|
||||
f"RoPE cache too small during torch.compile (need {needed_seq_len}, have {cur_len}). "
|
||||
"Increase initial rotary_seq_len/max_rotary_seq_len or avoid compiled generation."
|
||||
)
|
||||
|
||||
# Next power-of-two >= needed_seq_len (amortized growth), bounded by cap
|
||||
new_len = min(self.max_rotary_seq_len, 1 << (needed_seq_len - 1).bit_length())
|
||||
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
device = self.cos.device
|
||||
cos, sin = self._precompute_rotary_embeddings(seq_len=new_len, head_dim=head_dim, device=device)
|
||||
|
||||
# Preserve dtype/device invariants (precompute already returns bf16, but keep explicit)
|
||||
cos = cos.to(dtype=self.cos.dtype, device=device)
|
||||
sin = sin.to(dtype=self.sin.dtype, device=device)
|
||||
|
||||
# Overwrite existing registered buffers (persistent=False remains from initial registration)
|
||||
self.cos = cos
|
||||
self.sin = sin
|
||||
self.rotary_seq_len = new_len
|
||||
|
||||
@torch.no_grad()
|
||||
def init_weights(self):
|
||||
"""
|
||||
|
|
@ -387,14 +433,16 @@ class GPT(nn.Module):
|
|||
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
|
||||
# Ensure RoPE buffers cover absolute positions [T0, T0+T)
|
||||
self._ensure_rope_cache(T0 + T)
|
||||
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
||||
# If kv cache exists, offset RoPE by current absolute position
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x = self.transformer.wte(idx) # embed current token
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user