diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 04ee5c5..5e99c73 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -76,7 +76,7 @@ class CausalSelfAttention(nn.Module):
         self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
         self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
         self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
-        self.ve_gate_channels = 32
+        self.ve_gate_channels = 12
         self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
 
     def forward(self, x, ve, cos_sin, window_size, kv_cache):
@@ -91,13 +91,15 @@ class CausalSelfAttention(nn.Module):
         # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
         if ve is not None:
             ve = ve.view(B, T, self.n_kv_head, self.head_dim)
-            gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
+            gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3)
             v = v + gate.unsqueeze(-1) * ve
 
         # Apply Rotary Embeddings to queries and keys to get relative positional encoding
         cos, sin = cos_sin
         q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
         q, k = norm(q), norm(k) # QK norm
+        q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better
+        k = k * 1.15
 
         # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
         # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
@@ -208,7 +210,7 @@ class GPT(nn.Module):
         """
 
         # Embedding and unembedding
-        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
+        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
         torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
 
         # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
@@ -219,7 +221,7 @@
             torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
             torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
             torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
-            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
+            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc
             torch.nn.init.zeros_(block.mlp.c_proj.weight)
 
         # Per-layer scalars
@@ -230,10 +232,10 @@
         for ve in self.value_embeds.values():
             torch.nn.init.uniform_(ve.weight, -s, s)
 
-        # Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
+        # Gate weights init with small positive values so gates start slightly above neutral
        for block in self.transformer.h:
             if block.attn.ve_gate is not None:
-                torch.nn.init.zeros_(block.attn.ve_gate.weight)
+                torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
 
         # Rotary embeddings
         head_dim = self.config.n_embd // self.config.n_head
@@ -248,7 +250,7 @@
         for ve in self.value_embeds.values():
             ve.to(dtype=COMPUTE_DTYPE)
 
-    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
         # TODO: bump base theta more? e.g. 100K is more common more recently
         # autodetect the device from model embeddings
         if device is None:
@@ -280,7 +282,7 @@
         assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
         # Map characters to window sizes
         long_window = config.sequence_len
-        short_window = long_window // 2
+        short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
         char_to_window = {
             "L": (long_window, 0),
             "S": (short_window, 0),
@@ -353,7 +355,7 @@ class GPT(nn.Module):
             'total': total,
         }
 
-    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
+    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
         model_dim = self.config.n_embd
         ddp, rank, local_rank, world_size = get_dist_info()
 
@@ -373,10 +375,10 @@
         # Build param_groups with all required fields explicit
         param_groups = [
             # AdamW groups (embeddings, lm_head, scalars)
-            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
-            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
-            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
-            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
+            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
+            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
+            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
+            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
             dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
         ]
         # Muon groups (matrix params, grouped by shape for stacking)
@@ -384,7 +386,7 @@
             group_params = [p for p in matrix_params if p.shape == shape]
             param_groups.append(dict(
                 kind='muon', params=group_params, lr=matrix_lr,
-                momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
+                momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
             ))
 
         Factory = DistMuonAdamW if ddp else MuonAdamW
@@ -416,7 +418,7 @@
         x = norm(x)
 
         # Forward the lm_head (compute logits)
-        softcap = 20 # smoothly cap the logits to the range [-softcap, softcap]
+        softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
         logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
         logits = logits[..., :self.config.vocab_size] # slice to remove padding
         logits = logits.float() # switch to fp32 for logit softcap and loss computation
diff --git a/nanochat/optim.py b/nanochat/optim.py
index 42d862b..0ee2e27 100644
--- a/nanochat/optim.py
+++ b/nanochat/optim.py
@@ -113,7 +113,7 @@ def muon_step_fused(
 
     # Polar express
    X = g.bfloat16()
-    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
     if g.size(-2) > g.size(-1): # Tall matrix
         for a, b, c in polar_express_coeffs[:ns_steps]:
             A = X.mT @ X
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 4bf7959..cfbfe28 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -60,15 +60,13 @@ parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help=
 parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
 parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
 parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
+parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
 parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
 parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
-parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
-parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
-parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
-parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
-parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
+parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
+parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
+parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
 parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
 # Evaluation
 parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
@@ -311,7 +309,6 @@ optimizer = model.setup_optimizer(
     unembedding_lr=args.unembedding_lr * batch_lr_scale,
     embedding_lr=args.embedding_lr * batch_lr_scale,
     scalar_lr=args.scalar_lr * batch_lr_scale,
-    adam_betas=(args.adam_beta1, args.adam_beta2),
     # Muon hyperparameters
     matrix_lr=args.matrix_lr * batch_lr_scale,
     weight_decay=weight_decay_scaled,
@@ -360,7 +357,7 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
 
 # Learning rate schedule (linear warmup, constant, linear warmdown)
 def get_lr_multiplier(it):
-    warmup_iters = round(args.warmup_ratio * num_iterations)
+    warmup_iters = args.warmup_steps
     warmdown_iters = round(args.warmdown_ratio * num_iterations)
     if it < warmup_iters:
         return (it + 1) / warmup_iters
@@ -370,15 +367,15 @@
         progress = (num_iterations - it) / warmdown_iters
         return progress * 1.0 + (1 - progress) * args.final_lr_frac
 
-# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
+# Momentum scheduler for Muon optimizer (warms up to 0.97 over the first 400 steps)
 def get_muon_momentum(it):
-    frac = min(it / 300, 1)
-    momentum = (1 - frac) * 0.85 + frac * 0.95
+    frac = min(it / 400, 1)
+    momentum = (1 - frac) * 0.85 + frac * 0.97
     return momentum
 
-# Weight decay scheduler for Muon optimizer (linearly decays to zero over the course of training)
+# Weight decay scheduler for Muon optimizer (cosine decay to zero over the course of training)
 def get_weight_decay(it):
-    return weight_decay_scaled * (1 - it / num_iterations)
+    return weight_decay_scaled * 0.5 * (1 + math.cos(math.pi * it / num_iterations))
 
 # -----------------------------------------------------------------------------
 # Training loop
@@ -605,7 +602,7 @@ get_report().log(section="Base model training", data=[
         "Number of training tokens": total_tokens,
         "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params,
         "DDP world size": ddp_world_size,
-        "warmup_ratio": args.warmup_ratio,
+        "warmup_steps": args.warmup_steps,
         "warmdown_ratio": args.warmdown_ratio,
         "final_lr_frac": args.final_lr_frac,
     },
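
Note (not part of the patch): a minimal standalone sketch of how the revised scripts/base_train.py schedules behave after this change. The plateau branch of get_lr_multiplier is filled in from the hunk's own comment ("linear warmup, constant, linear warmdown"); num_iterations, weight_decay_scaled, and the probed steps are made-up values for illustration only.

import math

# Hypothetical values, for illustration only
num_iterations = 5000
warmup_steps = 40            # new --warmup-steps default
warmdown_ratio = 0.65        # new --warmdown-ratio default
final_lr_frac = 0.05         # new --final-lr-frac default
weight_decay_scaled = 0.28   # new --weight-decay default, before any batch scaling

def get_lr_multiplier(it):
    # Linear warmup over a fixed step count, constant plateau, linear warmdown to final_lr_frac
    warmup_iters = warmup_steps
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac

def get_weight_decay(it):
    # Cosine decay from the full weight decay down to zero over training (previously linear)
    return weight_decay_scaled * 0.5 * (1 + math.cos(math.pi * it / num_iterations))

for it in (0, 39, 1000, 3000, 4999):
    print(it, round(get_lr_multiplier(it), 4), round(get_weight_decay(it), 4))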