This commit is contained in:
Dipesh Babu 2026-02-16 15:22:20 +00:00 committed by GitHub
commit ee89623f4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 59 additions and 13 deletions

View File

@ -12,7 +12,6 @@ Notable features:
- Flash Attention 3 integration
"""
from functools import partial
from dataclasses import dataclass
import torch
@ -22,6 +21,7 @@ import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
@ -185,6 +185,37 @@ class GPT(nn.Module):
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
def _ensure_rope_cache(self, needed_seq_len: int):
    """
    Make sure the cached rotary embeddings (cos/sin) cover at least
    `needed_seq_len` positions, regrowing them lazily when they do not.
    Growing on demand avoids crashes when evaluation sees prompts longer
    than the length the cache was originally built for.
    """
    cached_len = self.cos.size(1)
    # Fast path: the existing cache is already long enough.
    if needed_seq_len <= cached_len:
        return
    # Double up to the next power of two >= needed_seq_len so that
    # repeated growth is amortized instead of happening every call.
    target_len = 1
    while target_len < needed_seq_len:
        target_len <<= 1
    per_head_dim = self.config.n_embd // self.config.n_head
    cos, sin = self._precompute_rotary_embeddings(
        seq_len=target_len,
        head_dim=per_head_dim,
        device=self.cos.device,
    )
    # Preserve the dtype of the buffers being replaced.
    cos, sin = cos.to(dtype=self.cos.dtype), sin.to(dtype=self.sin.dtype)
    # Overwriting via register_buffer keeps the persistent=False
    # semantics (the cache is still excluded from checkpoints).
    self.register_buffer("cos", cos, persistent=False)
    self.register_buffer("sin", sin, persistent=False)
@torch.no_grad()
def init_weights(self):
"""
@ -388,12 +419,14 @@ class GPT(nn.Module):
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
T0 = 0 if kv_cache is None else kv_cache.get_pos()
# Ensure cache covers absolute positions [T0, T0+T)
self._ensure_rope_cache(T0 + T)
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Forward the trunk of the Transformer

View File

@ -346,15 +346,28 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# Learning rate schedule (linear warmup, constant, linear warmdown)
def get_lr_multiplier(it):
    """
    Learning-rate multiplier schedule: linear warmup, constant plateau,
    then linear warmdown toward `args.final_lr_frac`.

    Note: optimizer steps run for it in [0, num_iterations - 1].
    """
    warmup_iters = int(round(args.warmup_ratio * num_iterations))
    warmdown_iters = int(round(args.warmdown_ratio * num_iterations))
    # Linear warmup (avoid division by zero when warmup_iters == 0)
    if warmup_iters > 0 and it < warmup_iters:
        return (it + 1) / warmup_iters
    # Warmdown should cover the last `warmdown_iters` optimizer steps:
    # it in [num_iterations - warmdown_iters, num_iterations - 1]
    if warmdown_iters > 0:
        warmdown_start = num_iterations - warmdown_iters
        # Ensure warmdown doesn't start before warmup ends (prevents overlap weirdness)
        warmdown_start = max(warmdown_start, warmup_iters)
        if it >= warmdown_start:
            # progress: 1.0 at warmdown_start, 0.0 at last optimizer step (num_iterations - 1)
            span = max(1, (num_iterations - 1) - warmdown_start)  # denom >= 1
            progress = (num_iterations - 1 - it) / span
            return progress * 1.0 + (1.0 - progress) * args.final_lr_frac
    # Constant plateau between warmup and warmdown
    return 1.0
# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
def get_muon_momentum(it):