L3 generalizes token embeddings by placing per-token lookup tables inside
the decoder stack. Unlike MoE, routing is static (determined by token ID),
eliminating router training and load-balancing losses.
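The contrast with MoE routing can be made concrete. In this sketch (hypothetical names, not repo code), the MoE path needs a trained router network, while the L3 path is a bare index by token ID:

```python
import torch
import torch.nn as nn

def moe_route(x: torch.Tensor, router: nn.Linear) -> torch.Tensor:
    # learned routing: the router must be trained, and an auxiliary
    # load-balancing loss is usually needed to keep experts utilized
    return router(x).argmax(dim=-1)

def static_route(token_ids: torch.Tensor) -> torch.Tensor:
    # static routing: the token ID *is* the route, so there is
    # nothing to train and no balancing loss to add
    return token_ids
```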
Implementation:
- nanochat/l3.py: LZW allocation algorithm and L3Layer module with
vectorized gather+pad+mask forward pass, tied/untied KV support
- GPT integration: L3 layers sit between decoder blocks, applied
residually (x = x + l3_layer(x, token_ids))
- CLI: --l3-after-layers, --l3-n-emb, --l3-d-up, --l3-k-max flags
with LZW precomputation from training data sample
- 17 tests covering allocation, layer, and GPT integration
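A hedged sketch of what a vectorized gather+pad+mask forward with static per-token slots could look like (the class name, shapes, and the attention-style mixing are illustrative assumptions, not the repo's L3Layer; precomputed slot tables stand in for the LZW allocation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class L3LayerSketch(nn.Module):
    """Each token ID owns up to k_max key/value slots in a shared table.
    Slots are gathered by token ID (static routing), padded slots are
    masked out, and the result is mixed back residually."""

    def __init__(self, n_emb, d_model, slot_ids, slot_mask):
        super().__init__()
        # slot_ids:  (vocab, k_max) long — indices into the shared table
        # slot_mask: (vocab, k_max) bool — True where a slot is real
        self.register_buffer("slot_ids", slot_ids)
        self.register_buffer("slot_mask", slot_mask)
        self.keys = nn.Embedding(n_emb, d_model)
        self.vals = nn.Embedding(n_emb, d_model)

    def forward(self, x, token_ids):
        ids = self.slot_ids[token_ids]            # (B, T, k_max)
        mask = self.slot_mask[token_ids]          # (B, T, k_max)
        k = self.keys(ids)                        # (B, T, k_max, D)
        v = self.vals(ids)                        # (B, T, k_max, D)
        att = (x.unsqueeze(-2) * k).sum(-1)       # (B, T, k_max) scores
        att = att.masked_fill(~mask, float("-inf"))
        w = F.softmax(att, dim=-1)                # padded slots get 0
        return x + (w.unsqueeze(-1) * v).sum(-2)  # residual update
```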
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Accept upstream's architectural changes wholesale:
- argparse replaces configurator.py across all scripts
- Unified MuonAdamW optimizer replaces separate AdamW + Muon
- Sliding window attention (SSSL pattern) + Flash Attention 3
- Value embeddings (ResFormer-style) with per-layer gating
- Per-layer learnable scalars (resid_lambdas, x0_lambdas)
- FP8 training support with Float8Linear
- Scaling laws (Power Lines batch sizing, T_epoch weight decay)
- Checkpoint resumption with dataloader state
- BOS-aligned bestfit-pad packing for SFT
- ChatCORE evaluation metric
- Consolidated base_loss.py into base_eval.py
- Removed mid_train.py (pipeline simplified)
Drop our MoE and tie_embeddings implementations in favor of
upstream's cleaner architecture. These can be re-added later
on top of the new codebase if needed.
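As one concrete illustration of the upstream deltas, the per-layer learnable scalars amount to small learned mixing weights on the residual and x0 streams. A minimal sketch (hypothetical class name; scalar gates per layer are an assumption about the exact form):

```python
import torch
import torch.nn as nn

class ScaledResidualStack(nn.Module):
    """Each layer mixes its block output into the residual stream with a
    learned weight (resid_lambdas) and re-injects the initial embedding
    x0 with another (x0_lambdas)."""

    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks
        n = len(blocks)
        self.resid_lambdas = nn.Parameter(torch.ones(n))   # residual gate
        self.x0_lambdas = nn.Parameter(torch.zeros(n))     # x0 re-injection

    def forward(self, x0: torch.Tensor) -> torch.Tensor:
        x = x0
        for i, block in enumerate(self.blocks):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + block(x)
        return x
```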
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Implement weight tying between token embeddings and lm_head to reduce
parameter count. When enabled, logits are scaled by 1/√d_model, lm_head
zeroing is skipped, and optimizer groups are deduplicated. Param counting
uses unique parameters, while the Chinchilla ratio calculation adds back the
would-be lm_head size for comparability.
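A minimal sketch of the tying plus logit scaling described above (hypothetical class name; the 1/√d_model factor follows the commit text):

```python
import torch
import torch.nn as nn

class TiedLMHead(nn.Module):
    """Weight tying: the token embedding matrix doubles as the output
    projection, so no separate lm_head parameter exists. Logits are
    scaled by 1/sqrt(d_model) to compensate for the shared matrix's
    embedding-scale initialization."""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def logits(self, h: torch.Tensor) -> torch.Tensor:
        # reuse the embedding matrix as the output projection, scaled
        return (h @ self.wte.weight.t()) / self.d_model ** 0.5
```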
Also add boolean flag parsing (--flag without =value) to the configurator,
an auto-derived log_every interval, and minor shell script fixes.
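The boolean-flag behavior can be sketched as a configurator-style override parser (hypothetical function, not the repo's configurator):

```python
import ast

def parse_overrides(argv):
    """'--key=value' sets key to the parsed literal (falling back to a
    plain string), while a bare '--flag' with no '=' sets flag=True."""
    overrides = {}
    for arg in argv:
        if not arg.startswith("--"):
            raise ValueError(f"expected --key[=value], got {arg!r}")
        body = arg[2:]
        if "=" in body:
            key, raw = body.split("=", 1)
            try:
                val = ast.literal_eval(raw)  # numbers, True/False, lists
            except (ValueError, SyntaxError):
                val = raw                    # fall back to plain string
            overrides[key] = val
        else:
            overrides[body] = True           # bare boolean flag
    return overrides
```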
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Store quantized input/weight and their inverse scales in _Float8Matmul ctx to avoid re-quantization in backward and reduce saved-activation memory without changing numerics.
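The caching idea can be sketched with a toy absmax int8-style quantization standing in for FP8 (illustrative, not the Float8Linear kernels): quantize once in forward, save the quantized operands and their scales in ctx, and reuse them in backward so nothing is re-quantized and the full-precision inputs need not be saved.

```python
import torch

class CachedQuantMatmul(torch.autograd.Function):
    """Forward quantizes x and w once and stashes the quantized tensors
    plus scales via ctx.save_for_backward; backward reuses them, so there
    is no second quantization pass and numerics match the forward."""

    @staticmethod
    def forward(ctx, x, w):
        # per-tensor absmax scaling (toy stand-in for FP8 scaling)
        x_scale = x.abs().max().clamp(min=1e-12) / 127.0
        w_scale = w.abs().max().clamp(min=1e-12) / 127.0
        x_q = (x / x_scale).round().clamp(-127, 127)
        w_q = (w / w_scale).round().clamp(-127, 127)
        # save quantized operands + scales; the fp inputs are not saved
        ctx.save_for_backward(x_q, w_q, x_scale, w_scale)
        return (x_q @ w_q.t()) * (x_scale * w_scale)

    @staticmethod
    def backward(ctx, grad_out):
        x_q, w_q, x_scale, w_scale = ctx.saved_tensors
        # reuse the cached quantized tensors (straight-through gradient):
        # dequantized w is w_q * w_scale, dequantized x is x_q * x_scale
        grad_x = (grad_out @ w_q) * w_scale
        grad_w = (grad_out.t() @ x_q) * x_scale
        return grad_x, grad_w
```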