Implement weight tying between token embeddings and lm_head to reduce
parameter count. When enabled, logits are scaled by 1/√d_model, lm_head
zeroing is skipped, and optimizer groups are deduplicated. Parameter counting
uses unique parameters, while the Chinchilla ratio calculation adds back the
would-be lm_head size so runs remain comparable.
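A minimal sketch of the mechanism, assuming a GPT-style module with a `wte` embedding and an untied-by-default `lm_head` (names and structure are illustrative, not the repo's exact code):

```python
import math
import torch
import torch.nn as nn

class TiedLMHead(nn.Module):
    """Sketch: token embedding whose weight tensor is shared with the output
    projection. Logits are scaled by 1/sqrt(d_model) to keep their magnitude
    reasonable without the usual zero-initialized lm_head."""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.wte = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight  # tie: one shared tensor

    def logits(self, hidden: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden) / math.sqrt(self.d_model)

model = TiedLMHead(vocab_size=64, d_model=16)

# Parameter counting (and optimizer group construction) must count the
# shared tensor exactly once:
unique = {id(p): p for p in model.parameters()}
n_params = sum(p.numel() for p in unique.values())
assert n_params == 64 * 16  # embedding and lm_head share one matrix
```

Note that `nn.Module.parameters()` already deduplicates shared tensors; the explicit `id()`-keyed dict just makes the invariant visible.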
Also adds boolean flag parsing (--flag without =value) to the configurator,
an auto-derived log_every interval, and minor shell script fixes.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
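The boolean flag parsing added to the configurator can be sketched as follows (a standalone illustration with hypothetical option names, not the actual `configurator.py` code):

```python
def parse_overrides(argv, config):
    """Sketch of --key=value override parsing where a bare --flag (no =value)
    sets a boolean option to True. `config` maps known option names to their
    current, typed default values."""
    for arg in argv:
        assert arg.startswith("--"), f"unrecognized argument: {arg}"
        if "=" in arg:
            key, val = arg[2:].split("=", 1)
        else:
            key, val = arg[2:], "True"  # bare --flag means True
        assert key in config, f"unknown option: {key}"
        default = config[key]
        if isinstance(default, bool):
            config[key] = val in ("1", "true", "True")
        else:
            config[key] = type(default)(val)  # cast to the default's type
    return config

cfg = parse_overrides(["--weight_tying", "--lr=0.02"],
                      {"weight_tying": False, "lr": 0.1})
# cfg == {"weight_tying": True, "lr": 0.02}
```

The `isinstance(default, bool)` check must come before the generic cast, since `bool("False")` would otherwise evaluate truthy.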
- Introduced `MOE_DEBUG_INTERVAL` parameter in `runmps.sh` to control debug logging frequency during training.
- Enhanced `base_train.py` to log gradients of the routed and shared expert weights at the configured interval, to aid in diagnosing training behavior.
- Updated `gpt.py` to adjust router bias calculations, improving load balancing among experts.
- Added unit tests in `test_moe.py` to validate the behavior of the MoE implementation and ensure correctness of gradient calculations.
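The exact bias adjustment in `gpt.py` isn't reproduced here; one common bias-based balancing scheme (in the spirit of auxiliary-loss-free load balancing) nudges a per-expert routing bias against each expert's recent load, as in this hedged sketch:

```python
import torch

def update_router_bias(bias, expert_load, lr=0.01):
    """Sketch: lower the routing bias of overloaded experts and raise it for
    underused ones, steering future top-k selection toward balance.
    `expert_load` holds the fraction of tokens each expert just handled."""
    target = 1.0 / bias.numel()  # ideal uniform share per expert
    with torch.no_grad():
        bias -= lr * torch.sign(expert_load - target)
    return bias

bias = torch.zeros(4)
load = torch.tensor([0.55, 0.25, 0.10, 0.10])  # expert 0 is overloaded
bias = update_router_bias(bias, load)
# expert 0's bias decreases; experts 2 and 3 increase; expert 1 is untouched
```

The bias only influences routing decisions, not the mixing weights, so it shifts load without distorting expert outputs.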
- Introduced parameters for Mixture of Experts (MoE) in `runmps.sh`, `base_train.py`, and `gpt.py`, allowing for dynamic configuration of experts during training.
- Enhanced `gpt.py` with new classes `MoEFeedForward` and `ExpertFFN` to implement MoE functionality in the model architecture.
- Updated `configurator.py` to handle type conversions for new MoE parameters.
- Improved logging in `base_train.py` to include MoE-related metrics and configurations during training.
- Added assertions and derived defaults for MoE parameters to ensure valid configurations.
- Implemented methods to estimate and log FLOPs for both dense and MoE active configurations during training.
- Enhanced gradient handling in `muon.py` to accommodate potential absence of gradients for unused experts.
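The class names `MoEFeedForward` and `ExpertFFN` come from the commit; their internals below are an illustrative sketch of top-k routing with an always-on shared expert, not the repo's exact implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ExpertFFN(nn.Module):
    """One expert: a plain two-layer feed-forward block."""
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.up = nn.Linear(d_model, d_hidden, bias=False)
        self.down = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.down(F.relu(self.up(x)))

class MoEFeedForward(nn.Module):
    """Sketch: route each token to its top-k of n_experts, plus a shared
    expert that sees every token."""
    def __init__(self, d_model, d_hidden, n_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList(
            ExpertFFN(d_model, d_hidden) for _ in range(n_experts))
        self.shared = ExpertFFN(d_model, d_hidden)

    def forward(self, x):                       # x: (tokens, d_model)
        scores = self.router(x)                 # (tokens, n_experts)
        weights, idx = scores.topk(self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)
        out = self.shared(x)                    # shared expert: every token
        for e, expert in enumerate(self.experts):
            tok, slot = (idx == e).nonzero(as_tuple=True)
            if tok.numel():                     # unused experts get no grad
                contrib = weights[tok, slot].unsqueeze(-1) * expert(x[tok])
                out = out.index_add(0, tok, contrib)
        return out

moe = MoEFeedForward(d_model=8, d_hidden=16)
y = moe(torch.randn(5, 8))
assert y.shape == (5, 8)
```

The `if tok.numel()` guard is why the Muon change above matters: an expert that receives no tokens in a step contributes nothing to the graph, so its parameters arrive at the optimizer with `grad is None`.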
- Added model tagging functionality to `runmps.sh`, allowing for dynamic model tagging based on the W&B run name.
- Updated `base_train.py`, `mid_train.py`, and `chat_sft.py` to utilize model tags for checkpoint management.
- Enhanced `base_eval.py` to accept model tags for loading models during evaluation.
- Improved handling of model tags to ensure proper checkpoint directory naming and logging.
- Introduced `kv_head_mult` to control the number of query heads sharing a key/value head in `base_train.py`, `mid_train.py`, and `runmps.sh`.
- Updated logging to include global tokens-per-second metrics during training.
- Added assertions to ensure `kv_head_mult` is valid and properly integrated into model calculations.
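One way the `kv_head_mult` arithmetic and its assertion can look (illustrative; the actual derivation lives in `base_train.py`):

```python
def derive_kv_heads(n_head: int, kv_head_mult: int) -> int:
    """Sketch: kv_head_mult query heads share each key/value head
    (grouped-query attention). kv_head_mult=1 recovers standard
    multi-head attention."""
    assert kv_head_mult >= 1 and n_head % kv_head_mult == 0, \
        "kv_head_mult must evenly divide the number of query heads"
    return n_head // kv_head_mult

assert derive_kv_heads(n_head=8, kv_head_mult=1) == 8  # plain MHA
assert derive_kv_heads(n_head=8, kv_head_mult=4) == 2  # 4 q-heads per kv-head
```

Fewer key/value heads shrink the KV cache and the attention projection parameters by the same factor, which is the usual motivation for GQA.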
- Added `dev/runmps_evals.sh` for evaluating checkpoints and logging results to W&B.
- Introduced `dev/runmps.sh` for orchestrating training stages with W&B support.
- Updated `.gitignore` to include `wandb/` and `.runmps_wandb_ids`.
- Marked `dev/runcpu.sh` as executable.
- Enhanced existing scripts to log metrics to W&B during training and evaluation.