Accept upstream's architectural changes wholesale:
- argparse replaces configurator.py across all scripts
- Unified MuonAdamW optimizer replaces separate AdamW + Muon
- Sliding window attention (SSSL pattern) + Flash Attention 3
- Value embeddings (ResFormer-style) with per-layer gating
- Per-layer learnable scalars (resid_lambdas, x0_lambdas)
- FP8 training support with Float8Linear
- Scaling laws (Power Lines batch sizing, T_epoch weight decay)
- Checkpoint resumption with dataloader state
- BOS-aligned bestfit-pad packing for SFT
- ChatCORE evaluation metric
- Consolidated base_loss.py into base_eval.py
- Removed mid_train.py (pipeline simplified)
Drops our MoE and tie_embeddings implementations in favor of
upstream's cleaner architecture. These can be re-added later
on top of the new codebase if needed.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Implement weight tying between token embeddings and lm_head to reduce
parameter count. When enabled, logits are scaled by 1/√d_model, lm_head
zeroing is skipped, and optimizer groups are deduplicated. Param counting
uses unique parameters while Chinchilla ratio calculation adds back the
would-be lm_head size for comparability.
Also adds boolean flag parsing (--flag without =value) to the configurator,
an auto-derived log_every interval, and minor shell script fixes.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add --run=$WANDB_RUN to base_train, mid_train, and chat_sft calls
- Ensures wandb logging works when WANDB_RUN environment variable is set
- Matches the behavior in speedrun.sh
Co-authored-by: svlandeg <svlandeg@github.com>
The new DataLoader ensures that every token sequence in train/val batches has a BOS token
at the beginning. Therefore, no token streams start abruptly in the middle of a document,
which could be confusing for the model. Note that this changes the loss scale because there
are fewer confusing tokens in the train/val batches. The main downside is that we now waste
about 35% of tokens due to cropping. This is ok because we have a lot of data. See dev/LOG.md
entry for this change for a lot more information.
- Introduced `MOE_DEBUG_INTERVAL` parameter in `runmps.sh` to control debug logging frequency during training.
- Enhanced `base_train.py` to log gradients of routed and shared weights at specified intervals, aiding in monitoring model performance.
- Updated `gpt.py` to adjust router bias calculations, improving load balancing among experts.
- Added unit tests in `test_moe.py` to validate the behavior of the MoE implementation and ensure correctness of gradient calculations.
- Introduced parameters for Mixture of Experts (MoE) in `runmps.sh`, `base_train.py`, and `gpt.py`, allowing for dynamic configuration of experts during training.
- Enhanced `gpt.py` with new classes `MoEFeedForward` and `ExpertFFN` to implement MoE functionality in the model architecture.
- Updated `configurator.py` to handle type conversions for new MoE parameters.
- Improved logging in `base_train.py` to include MoE-related metrics and configurations during training.
- Added assertions and derived defaults for MoE parameters to ensure valid configurations.
- Implemented methods to estimate and log FLOPs for both dense and MoE active configurations during training.
- Enhanced gradient handling in `muon.py` to accommodate potential absence of gradients for unused experts.
- Added model tagging functionality to `runmps.sh`, allowing for dynamic model tagging based on the W&B run name.
- Updated `base_train.py`, `mid_train.py`, and `chat_sft.py` to utilize model tags for checkpoint management.
- Enhanced `base_eval.py` to accept model tags for loading models during evaluation.
- Improved handling of model tags to ensure proper checkpoint directory naming and logging.
- Introduced `kv_head_mult` to control the number of query heads sharing a key/value head in `base_train.py`, `mid_train.py`, and `runmps.sh`.
- Updated logging to include global token per second metrics during training.
- Added assertions to ensure `kv_head_mult` is valid and properly integrated into model calculations.