diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 0b822e4..8023ebb 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -47,12 +47,27 @@ class Linear(nn.Linear): Replaces autocast: master weights stay fp32 for optimizer precision, but matmuls run in the activation dtype (typically bf16 from embeddings).""" def forward(self, x): - return F.linear(x, self.weight.to(dtype=x.dtype)) + w = self.weight + if w.dtype != x.dtype: + w = w.to(dtype=x.dtype) + return F.linear(x, w) + + +class EmbeddingLinear(nn.Module): + """Lightweight linear layer for lm_head without redundant dtype casting.""" + def __init__(self, in_features, out_features, bias=False, device=None, dtype=None): + super().__init__() + assert not bias + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype)) + def forward(self, x): + return F.linear(x, self.weight) def has_ve(layer_idx, n_layer): - """Returns True if GPT layer should have Value Embedding (alternating, last layer always included).""" - return layer_idx % 2 == (n_layer - 1) % 2 + """Returns True if GPT layer should have Value Embedding (every 3rd layer, last layer always included).""" + return layer_idx % 3 == (n_layer - 1) % 3 def apply_rotary_emb(x, cos, sin): assert x.ndim == 4 # multihead attention @@ -172,19 +187,16 @@ class GPT(nn.Module): "wte": nn.Embedding(padded_vocab_size, config.n_embd), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), }) - self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False) + self.lm_head = EmbeddingLinear(config.n_embd, padded_vocab_size, bias=False) # Per-layer learnable scalars (inspired by modded-nanogpt) # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral) # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled) # Separate parameters so they can have different optimizer treatment 
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() - # Smear: mix previous token's embedding into current token (cheap bigram-like info) - self.smear_gate = Linear(24, 1, bias=False) - self.smear_lambda = nn.Parameter(torch.zeros(1)) # Backout: subtract cached mid-layer residual before final norm to remove low-level features self.backout_lambda = nn.Parameter(0.2 * torch.ones(1)) - # Value embeddings (ResFormer-style): alternating layers, last layer always included + # Value embeddings (ResFormer-style): every 3rd layer, last layer always included head_dim = config.n_embd // config.n_head kv_dim = config.n_kv_head * head_dim self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)}) @@ -224,19 +236,27 @@ class GPT(nn.Module): for block in self.transformer.h: torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers torch.nn.init.uniform_(block.attn.c_k.weight, -s, s) - torch.nn.init.uniform_(block.attn.c_v.weight, -s, s) - torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero + torch.nn.init.uniform_(block.attn.c_v.weight, -0.85 * s, 0.85 * s) + torch.nn.init.uniform_(block.attn.c_proj.weight, -0.008, 0.008) # small nonzero init torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc torch.nn.init.zeros_(block.mlp.c_proj.weight) # Per-layer scalars - # Per-layer resid init: stronger residual at early layers, weaker at deep layers + # Per-layer resid init: exponential decay, stronger at early layers + import math n_layer = self.config.n_layer + resid_start, resid_end = 1.18, 1.06 + resid_decay = math.log(resid_start / resid_end) / max(n_layer - 1, 1) for i in range(n_layer): - self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1)) - 
# Decaying x0 init: earlier layers get more input embedding blending + self.resid_lambdas.data[i] = resid_start * math.exp(-resid_decay * i) + # x0 init: first-half only, linearly decaying, zero for deep layers + half_depth = max(1, n_layer // 2) for i in range(n_layer): - self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1)) + if i < half_depth: + frac = i / max(half_depth - 1, 1) + self.x0_lambdas.data[i] = 0.24 * (1.0 - frac) + 0.08 * frac + else: + self.x0_lambdas.data[i] = 0.0 # Value embeddings (init like c_v: uniform with same std) for ve in self.value_embeds.values(): @@ -257,10 +277,11 @@ class GPT(nn.Module): # because GradScaler cannot unscale fp16 gradients. if COMPUTE_DTYPE != torch.float16: self.transformer.wte.to(dtype=COMPUTE_DTYPE) + self.lm_head.to(dtype=COMPUTE_DTYPE)  # NOTE(review): this cast is skipped in the fp16 path, but EmbeddingLinear no longer casts at runtime — fp16 activations against fp32 lm_head weights would make F.linear fail; confirm the fp16 code path still works for ve in self.value_embeds.values(): ve.to(dtype=COMPUTE_DTYPE) - def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None): + def _precompute_rotary_embeddings(self, seq_len, head_dim, base=200000, device=None): # TODO: bump base theta more? e.g. 100K is more common more recently # autodetect the device from model embeddings if device is None: @@ -292,7 +313,7 @@ class GPT(nn.Module): assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
# Map characters to window sizes long_window = config.sequence_len - short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 768) + short_window = max(256, -(-long_window // 8 // 128) * 128) # ceil to FA3 tile size (2048 -> 256) char_to_window = { "L": (long_window, 0), "S": (short_window, 0), @@ -326,7 +347,7 @@ class GPT(nn.Module): value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + self.resid_lambdas.numel() + self.x0_lambdas.numel() + - self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()) + self.backout_lambda.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -354,7 +375,7 @@ class GPT(nn.Module): value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters()) transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) - scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.backout_lambda.numel() total = wte + value_embeds + lm_head + transformer_matrices + scalars assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" return { @@ -377,8 +398,8 @@ class GPT(nn.Module): lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) + backout_params = 
[self.backout_lambda] + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(backout_params) # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -392,7 +413,7 @@ class GPT(nn.Module): dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01), dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05), dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 - dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0), + dict(kind='adamw', params=backout_params, lr=0.15, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0), ] # Muon groups (matrix params, grouped by shape for stacking) for shape in sorted({p.shape for p in matrix_params}): @@ -424,25 +445,6 @@ class GPT(nn.Module): x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path) x = norm(x) - # Smear: mix previous token's embedding into current position (cheap bigram info) - if kv_cache is None: - # Training / naive generate: full sequence available, use fast slice - assert T > 1, "Training forward pass should have T > 1" - gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24])) - x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1) - else: - # KV cache inference: read prev embedding from cache, store current for next step - x_pre_smear = kv_cache.prev_embedding - kv_cache.prev_embedding = x[:, -1:, :] - if T > 1: - # Prefill: apply smear to positions 1+, same as training - gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24])) - x = torch.cat([x[:, :1], x[:, 1:] + 
gate * x[:, :-1]], dim=1) - elif x_pre_smear is not None: - # Decode: single token, use cached prev embedding - gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24])) - x = x + gate * x_pre_smear - # Forward the trunk of the Transformer x0 = x # save initial normalized embedding for x0 residual n_layer = self.config.n_layer diff --git a/nanochat/optim.py b/nanochat/optim.py index 56e85e1..b845576 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -255,9 +255,12 @@ class MuonAdamW(torch.optim.Optimizer): second_momentum_buffer = state["second_momentum_buffer"] red_dim = -1 if shape[-2] >= shape[-1] else -2 - # Stack grads and params (NOTE: this assumes all params have the same shape) - stacked_grads = torch.stack([p.grad for p in params]) - stacked_params = torch.stack(params) + # Stack grads and params into contiguous staging buffers, allocated fresh each step (NOTE: this assumes all params have the same shape) + stacked_grads = torch.empty(num_params, *shape, dtype=dtype, device=device) + stacked_params = torch.empty(num_params, *shape, dtype=dtype, device=device) + for i, param in enumerate(params): + stacked_grads[i].copy_(param.grad) + stacked_params[i].copy_(param) # Fill all the 0-D tensors with current values self._muon_momentum_t.fill_(group["momentum"]) @@ -280,7 +283,8 @@ class MuonAdamW(torch.optim.Optimizer): ) # Copy back to original params - torch._foreach_copy_(params, list(stacked_params.unbind(0))) + for i, param in enumerate(params): + param.copy_(stacked_params[i]) @torch.no_grad() def step(self): diff --git a/scripts/base_train.py b/scripts/base_train.py index a161c47..0e78fb4 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -65,8 +65,8 @@ parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious w parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars 
(resid_lambdas, x0_lambdas)") parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup") -parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown") -parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR") +parser.add_argument("--warmdown-ratio", type=float, default=0.58, help="ratio of iterations for LR warmdown") +parser.add_argument("--final-lr-frac", type=float, default=0.10, help="final LR as fraction of initial LR") parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") # Evaluation parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") @@ -368,7 +368,7 @@ def get_lr_multiplier(it): progress = (num_iterations - it) / warmdown_iters return progress * 1.0 + (1 - progress) * args.final_lr_frac -# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown) +# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.92 during LR warmdown) def get_muon_momentum(it): warmdown_iters = round(args.warmdown_ratio * num_iterations) warmdown_start = num_iterations - warmdown_iters @@ -377,7 +377,7 @@ def get_muon_momentum(it): return (1 - frac) * 0.85 + frac * 0.97 elif it >= warmdown_start: progress = (it - warmdown_start) / warmdown_iters - return 0.97 * (1 - progress) + 0.90 * progress + return 0.97 * (1 - progress) + 0.92 * progress else: return 0.97