nanochat/dev/LOG.md


Experiment Log

A running summary documenting some experiments and findings. Started ~Jan 7 2026.


2026-01-11: Sliding Window Attention

Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern.

Pattern string configuration:

  • New --window_pattern CLI arg and GPTConfig.window_pattern field
  • Pattern is tiled across layers (e.g., SSSL for 20 layers → SSSLSSSLSSSLSSSLSSSL)
  • Final layer always forced to L (full context) regardless of pattern
  • Short window = sequence_len // 2
  • Long window = sequence_len (full context)
  • All models trained before this change used L (full context) in every layer. Checkpoint loading now fills in this param for old models, see _patch_missing_config_keys
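The tiling rules above can be sketched in a few lines of plain Python. This is only an illustration of the rules as described in this entry; the helper name `layer_window_sizes` is hypothetical and is not taken from the codebase.

```python
def layer_window_sizes(pattern: str, n_layer: int, sequence_len: int) -> list[int]:
    """Tile an S/L pattern across layers and map each layer to a window length.

    Hypothetical helper for illustration; the real config handling may differ.
    """
    if n_layer < 1:
        raise ValueError("n_layer must be at least 1")
    if not pattern or set(pattern) - {"S", "L"}:
        raise ValueError(f"pattern must be a non-empty string of 'S'/'L', got {pattern!r}")
    # Repeat the pattern until it covers every layer.
    tiled = [pattern[i % len(pattern)] for i in range(n_layer)]
    # The final layer is always forced to full context.
    tiled[-1] = "L"
    short, long = sequence_len // 2, sequence_len
    return [short if c == "S" else long for c in tiled]


sizes = layer_window_sizes("SSSL", n_layer=20, sequence_len=2048)
print(sizes[:4])  # first tile of the pattern: three short windows, then one long
```

Note that a pattern such as "S" still yields a full-context final layer, because the last entry is overridden after tiling.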

Quick experiments showed that SSSL (every 4th layer is long) works well, providing a good balance between compute savings and model quality. This is now the default.


2026-01-11: Flash Attention 3 Integration

Replaced PyTorch's scaled_dot_product_attention (FA2) with Flash Attention 3 for training and inference.

Changes Made

1. FA3 via kernels package

  • Official FA3 is "beta" and requires building from source (painful)
  • Using kernels package from HuggingFace Hub: get_kernel('varunneal/flash-attention-3')
  • Loads pre-built wheels, works out of the box on H100

2. Simplified attention code

  • FA3 uses (B, T, H, D) layout matching our projection output directly - no transpose needed
  • Training: flash_attn.flash_attn_func(q, k, v, causal=True)
  • Inference: flash_attn.flash_attn_with_kvcache() handles all cache cases in one call
  • Removed 3 separate FA2 code paths (training, single-token, chunk inference)
  • GQA handled automatically when n_kv_heads < n_heads

3. Rewrote KVCache for FA3

  • Old format: (num_layers, 2, B, H, T, D) combined tensor
  • New format: separate k_cache and v_cache of shape (num_layers, B, T, H, D)
  • FA3 updates cache in-place during flash_attn_with_kvcache
  • Position tracked via cache_seqlens tensor (int32, per batch element)
  • Simpler API: get_layer_cache(), advance(), reset(), prefill()

Results

  • ~9% improvement in tok/sec during training out of the box
  • Benchmarks showed FA3 is 2x faster than FA2 at realistic training sizes (batch=32, seq=2048)
  • FA3 supports sliding window via window_size=(left, 0), which is huge and expected to give further improvements. This is ready to tune but keeping full context for now.

2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)

Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.

Changes Made

1. x0_lambdas (x0 residual connections)

  • Save initial normalized embedding as x0 after norm(wte(idx))
  • At each layer, blend x0 back in: x = resid_lambdas[i] * x + x0_lambdas[i] * x0
  • Zero-initialized, so disabled at start; model learns which layers benefit from the shortcut
  • Provides direct path from embedding to deep layers, helps preserve token information

2. resid_lambdas (residual stream scaling)

  • Per-layer multiplicative scaling of the residual stream
  • Initialized to 1.0 (neutral, standard transformer behavior)
  • Allows model to learn to amplify/dampen residual at each layer
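The two scalar types combine into a single blend per layer. The toy sketch below uses plain Python lists instead of tensors and a stand-in `toy_block`, both of which are assumptions made for illustration. It also assumes the blend is applied just before each block, which this entry does not state explicitly. The point is that the initial values (resid_lambdas = 1.0, x0_lambdas = 0.0) reproduce the plain residual stack.

```python
def blend(x, x0, resid_lambda, x0_lambda):
    """Elementwise resid_lambda * x + x0_lambda * x0 on plain Python lists."""
    return [resid_lambda * a + x0_lambda * b for a, b in zip(x, x0)]


def forward(x0, blocks, resid_lambdas, x0_lambdas):
    """Run a stack of blocks, blending x0 back in before each one (assumed order)."""
    x = list(x0)
    for i, block in enumerate(blocks):
        x = blend(x, x0, resid_lambdas[i], x0_lambdas[i])
        x = block(x)
    return x


def toy_block(x):
    # Stand-in for a transformer block: adds a constant to the residual stream.
    return [v + 1.0 for v in x]


x0 = [0.5, -2.0]
blocks = [toy_block] * 3

# At initialization the scalars are neutral, so this is the ordinary stack.
out_init = forward(x0, blocks, resid_lambdas=[1.0] * 3, x0_lambdas=[0.0] * 3)

# A nonzero x0 lambda on the last layer mixes the embedding back in there.
out_mixed = forward(x0, blocks, resid_lambdas=[1.0] * 3, x0_lambdas=[0.0, 0.0, 0.5])
print(out_init, out_mixed)
```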

3. DistAdamW small parameter handling

  • Added support for parameters with < 1024 elements (like the scalar lambdas)
  • Small params use all_reduce instead of reduce_scatter/all_gather
  • Fixes crash when param shape isn't divisible by world_size

Key Finding: Different LR Sensitivity

The two scalar types need very different learning rates:

  • x0_lambdas (additive): Can use normal LR (~0.5). Adding a fraction of x0 is forgiving.
  • resid_lambdas (multiplicative): Needs ~100x smaller LR (~0.005). Multiplying the residual compounds through layers.

Implementation: resid_params gets scalar_lr * 0.01, x0_params gets full scalar_lr.

Experiment Results

Swept --scalar_lr (controlling x0_lambdas) at multiple depths:

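| Depth | Baseline val_bpb (disabled) | Best scalar_lr | Best val_bpb | Δ bpb |
|-------|-----------------------------|----------------|--------------|---------|
| d8 | 1.0885 | 0.20 | 1.0782 | -0.0103 |
| d12 | 0.9770 | 0.60 | 0.9693 | -0.0077 |
| d16 | 0.9059 | 0.20 | 0.9002 | -0.0057 |
| d20 | 0.8565 | 0.10 | 0.8526 | -0.0039 |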

Observations:

  • Consistent improvement across all model sizes
  • Optimal LR varies by depth; default of 0.5 is reasonable, but 0.6 is better for d12
  • Adding resid_lambdas (with 0.01x LR) gives small additional improvement over x0 alone

Meta Device Footgun

Important lesson: __init__ runs in meta device context, so any tensor values set there are fake. Must initialize actual values in init_weights(). Added docstring warning to __init__.

Summary

Added --scalar_lr (default 0.5) controlling learnable per-layer scalars. The formula x = resid_lambdas[i] * x + x0_lambdas[i] * x0 gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.


2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay

Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.

Changes Made

1. Polar Express Orthogonalization

  • Replaced Newton-Schulz iteration with "Polar Express Sign Method" from arxiv.org/pdf/2505.16932
  • Uses 5 different coefficient tuples (one per iteration) instead of fixed coefficients
  • Both methods kept in code for easy comparison (zeropower_via_polar_express vs zeropower_via_newtonschulz5)
  • Result: No dramatic/noticeable difference in training, but keeping the new Polar Express as default.

2. Variance Reduction (NorMuon-style)

  • Added low-rank variance estimator similar to Adafactor (arxiv.org/pdf/2510.05491)
  • Maintains second_momentum_buffer with shape [rows, 1] or [1, cols] (whichever is smaller)
  • Normalizes updates based on running per-row/col variance estimate (beta2=0.95)
  • Memory overhead: ~1/max(rows, cols) per param, negligible
  • Result: Led to a very small improvement, kept and enabled by default.
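The buffer shape rule and the memory overhead claim can be checked with a little arithmetic. The helper names below are hypothetical, and the 1280 × 5120 matrix is just an example size, not one taken from the experiments.

```python
def second_moment_shape(rows: int, cols: int) -> tuple[int, int]:
    """Keep one statistic per row or per column, whichever buffer is smaller."""
    return (rows, 1) if rows <= cols else (1, cols)


def overhead_fraction(rows: int, cols: int) -> float:
    """Extra memory for the buffer, as a fraction of the parameter's own size."""
    buf_rows, buf_cols = second_moment_shape(rows, cols)
    return (buf_rows * buf_cols) / (rows * cols)


# Example: a 1280 x 5120 weight matrix (illustrative size).
shape = second_moment_shape(1280, 5120)
frac = overhead_fraction(1280, 5120)
print(shape, frac)
```

In both orientations the fraction works out to 1/max(rows, cols), which is why the overhead is negligible for matrices of this size.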

3. Cautious Weight Decay

  • Only decays weights where update * weight >= 0 (same sign) from arxiv.org/abs/2411.16085
  • Standard WD always pulls toward zero; cautious WD skips decay when gradient is pushing weight away from zero
  • Implementation note: Had to inline the logic rather than use a separate @torch.compile function. Passing changing float values (like weight_decay during scheduling) as function arguments triggers recompilation. Reading from group["weight_decay"] inside the step avoids this.
  • Result: Solid improvement; the cautious variant outperformed standard weight decay.
  • Now defaults to ON for Muon via the weight_decay param. AdamW weight decay is hardcoded to 0; might try to re-tune this later.
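The masking rule can be illustrated elementwise on plain Python lists. This is a sketch under stated assumptions, not the actual Muon step: it assumes the decay term is scaled by the learning rate (decoupled style), and the function name `cautious_decay_step` is hypothetical. Hyperparameters are read from a `group` dict inside the step, echoing the implementation note above.

```python
def cautious_decay_step(weights, updates, group):
    """Apply an update plus weight decay masked to entries where update * weight >= 0."""
    lr = group["lr"]
    wd = group["weight_decay"]  # read inside the step rather than passed as a float arg
    new_weights = []
    for w, u in zip(weights, updates):
        # Decay only where the update and the weight share a sign.
        mask = 1.0 if u * w >= 0 else 0.0
        new_weights.append(w - lr * u - lr * wd * w * mask)
    return new_weights


group = {"lr": 0.1, "weight_decay": 0.2}
weights = [1.0, 1.0, -2.0]
updates = [0.5, -0.5, -1.0]
stepped = cautious_decay_step(weights, updates, group)
print(stepped)
```

The second entry is the interesting case: the update pushes the weight away from zero, so the decay term is skipped for it.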

4. Weight decay schedule

  • Added a linear weight decay schedule, on by default, that scales weight decay from 1.0 to 0.0 (i.e. start with max weight decay at the beginning of training, then ramp to 0 by the end). This worked better than a static setting in experiments. (modded-nanogpt has the same schedule, but it is implemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule.)
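A minimal sketch of such a schedule is shown below. The function name and the choice to reach exactly zero at `num_steps` are assumptions for illustration; the real training loop may index steps differently.

```python
def weight_decay_at(step: int, num_steps: int, base_weight_decay: float) -> float:
    """Linearly ramp the weight decay multiplier from 1.0 at step 0 to 0.0 at num_steps."""
    multiplier = max(0.0, 1.0 - step / num_steps)
    return base_weight_decay * multiplier


schedule = [weight_decay_at(s, num_steps=1000, base_weight_decay=0.2) for s in (0, 500, 1000)]
print(schedule)
```

Computing the multiplier directly from the step count keeps the weight decay ramp independent of the learning rate schedule.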

Weight Decay Scaling Experiments

Swept weight decay values at d8, d12, d16, d20 to find optimal values and scaling law.

Optimal Values Found:

| Depth | Width (channels) | Optimal WD |
|-------|------------------|------------|
| d8 | 512 | ~0.40 |
| d12 | 768 | ~0.22 |
| d16 | 1024 | ~0.10 |
| d20 | 1280 | ~0.08 |

Scaling Law:

  • Fit power law: WD = k / channels^α in log-log space
  • Found α ≈ 1.97 (approximately 2), meaning WD ∝ 1/width²

Practical Formula:

WD_target = WD_reference × (d_reference / d_target)²

Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08
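The practical formula is easy to wrap in a helper. The function name below is hypothetical. Depth can stand in for width in the ratio because width is 64 × depth for every row of the table above (512/8 = 768/12 = 1024/16 = 1280/20 = 64).

```python
def scale_weight_decay(wd_reference: float, depth_reference: int, depth_target: int) -> float:
    """WD_target = WD_reference * (d_reference / d_target)^2."""
    return wd_reference * (depth_reference / depth_target) ** 2


# Extrapolate from the d12 optimum to the other depths in the sweep.
for depth in (8, 16, 20):
    print(f"d{depth}: {scale_weight_decay(0.22, 12, depth):.4f}")
```

The d20 value reproduces the worked example (about 0.08). The other depths can be compared against the measured optima in the table to see how closely the 1/width² rule tracks them.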

Reference: Moonlight paper uses fixed WD=0.1 for their 15B MoE model. Our experiments indicated a scaling law where the optimal WD changed with depth, so we go along with the empirical scaling law.

Summary

Muon now uses Polar Express orthogonalization, Adafactor-style variance reduction, and cautious weight decay with a schedule that ramps linearly to zero. All of these changes follow the modded-nanogpt repo, and each was validated individually to yield improvements in nanochat, with the exception of the Polar Express change, which was in the noise. Weight decay is on by default and configurable with --weight_decay, defaulting to 0.2 with ∝ 1/width² scaling. The meaning of the --weight_decay kwarg therefore changes with this update: it used to configure standard weight decay for AdamW, and it now applies exclusively to Muon (AdamW is hardcoded to 0.0) and is scaled based on depth.


2026-01-08: exp_grad_clip - Gradient Clipping

Hypothesis: Gradient clipping may be unnecessary overhead. Tested L2 norm clipping at various thresholds (0.25, 0.5, 1.0, 2.0) and elementwise clipping.

Results:

  • No benefit at any scale tested (d12, d20)
  • All variants within noise (~0.9827 val_bpb)
  • Grad norm never exceeds 1.0 naturally, so clipping is always inactive
  • Clipping adds ~2% time overhead from the all-reduce

Bug Found: Original implementation clipped local gradients before sync. Since this codebase doesn't use DDP (gradient sync is in the optimizers), each rank was clipping based on its own local norm. Fixed on the branch with proper distributed all-reduce.
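A toy example shows why clipping before sync matters. Ranks are simulated with plain lists and gradient sync is modeled as a simple average; both are assumptions made for illustration, and the actual fix on the branch used distributed collectives that are not reproduced here.

```python
import math


def l2_norm(vec):
    return math.sqrt(sum(v * v for v in vec))


def clip(vec, max_norm):
    """Scale vec down so its L2 norm is at most max_norm."""
    norm = l2_norm(vec)
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [v * scale for v in vec]


def average(grads_per_rank):
    """Stand-in for gradient sync: elementwise mean across ranks."""
    n = len(grads_per_rank)
    return [sum(col) / n for col in zip(*grads_per_rank)]


# Two simulated ranks with very different local gradient norms.
local_grads = [[3.0, 0.0], [0.0, 0.4]]

# Buggy order: each rank clips by its own local norm, then gradients are synced.
buggy = average([clip(g, 1.0) for g in local_grads])

# Intended order: sync first, then clip by the norm of the synced gradient.
fixed = clip(average(local_grads), 1.0)

print(buggy, fixed)
```

In the buggy order only the rank with the large gradient is scaled down, so the synced gradient changes direction. In the intended order the direction of the averaged gradient is preserved and only its length is capped.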

Observation: modded-nanogpt does not appear to clip either right now.

Summary: Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This improves MFU slightly because we no longer have to calculate and sync grad norms.