Merge remote-tracking branch 'upstream/master' into fix-batch-size-assertion

This commit is contained in:
suraj-self 2026-02-21 08:30:41 +05:30
commit d489a1fa22
8 changed files with 135 additions and 34 deletions

View File

@ -4,6 +4,95 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-02-19: Mixture of Experts (negative)
Implemented a DeepSeekV3-style Mixture of Experts layer as a drop-in replacement for the dense MLP. The MoE branch works and improves per-step validation loss, but is not a net improvement on wall clock time due to MoE overhead (at least for our scale of interest of approx GPT-2 capability).
### Implementation
Follows DeepSeekV3 and using torchtitan as reference:
- **8 routed experts, top-2 routing** with sigmoid gating (not softmax)
- **1 shared expert** (dense MLP processing all tokens, following DeepSeekV3)
- **Auxiliary-loss-free load balancing** (DeepSeekV3's expert bias nudging)
- **Iso-FLOP sizing**: `expert_hidden_dim = round(4 * dim / (top_k + num_shared) / 128) * 128`, so active FLOPs per token match the dense MLP
- **`torch._grouped_mm`** for dispatching tokens to experts in a single kernel (instead of a Python for-loop)
- **3D expert weight tensors** `(num_experts, hidden, dim)` — Muon's Polar Express operates on the last two dims, so each expert is independently orthogonalized
- **Active parameter counting** for scaling laws (only `top_k + shared` experts, not all 8)
### What was easy
- The core MoE forward pass: router, sort tokens by expert, grouped matmul, scatter back. Conceptually clean.
- Shared expert: just an `nn.Linear` MLP that runs on all tokens alongside the routed path.
- 3D expert params + Muon: only required fixing `second_momentum_buffer` shape to preserve leading dims.
- Load balancing: DeepSeekV3's bias nudging is simple and effective (~10 lines).
### What was hard / ugly
- **`torch._grouped_mm` quirks**: requires bf16 (not fp32), column-major right operand, int32 cumulative offsets. The API is undocumented and only discoverable by trial and error.
- **Token count padding**: torchtitan pads each expert's token count to alignment multiples (8 for bf16) for better grouped_mm throughput. We implemented this with both a pure PyTorch approach and a copy of torchtitan's Triton kernel. Both compiled cleanly (0 graph breaks), but with ~65K tokens across 8 experts, each expert already gets ~8K tokens which is well-aligned. The padding overhead (gather/scatter) actually regressed MFU from 35% to 33%. Reverted.
- **FP8 + MoE**: `torch._grouped_mm` does NOT support FP8. There's a separate `torch._scaled_grouped_mm` API that requires per-row scaling (not per-tensor like our `Float8Linear`). The backward pass for weight gradients needs per-group column-wise scaling, which torchao implements with custom Triton kernels. We investigated thoroughly (see `dev/moe_fp8.md`) but did not implement — would require either depending on `torchao.prototype` (unstable) or writing ~200 lines of custom autograd + quantization code. Partial FP8 support exists: the shared expert's `nn.Linear` layers do get converted, but the routed experts (3D `nn.Parameter`) stay in bf16.
### Results
- d18: MFU dropped from ~46% to ~35% (the grouped_mm dispatch + token sorting overhead is significant)
- Per-step improvement in validation loss does not compensate for the throughput hit
- Net negative on wall clock time
### What remains (if revisited)
- **FP8 for routed experts**: Use `torch._scaled_grouped_mm` with a custom `_Float8GroupedMatmul` autograd function, with bf16 fallback for weight gradient (avoiding the per-group column-wise Triton kernels).
What's really needed is a fused "FlashMoE" kernel that handles routing + expert dispatch + matmul in one shot (like FlashAttention did for attention), with all the needed features. This doesn't exist yet. Rawdogging MoE with current PyTorch primitives is painful — lots of sorting, gathering, scattering, and layout wrangling around the actual compute.
### Verdict
MoE is not worth the trouble for nanochat right now. The code bloat is substantial (moe.py, router, shared expert, load balancing, optimizer fixes, FP8 gaps, active param counting) and the performance is worse wall-clock at our scale of interest. The fundamental issue is that the grouped_mm dispatch overhead eats the FLOP savings from sparsity, at least at our model scales and sequence lengths.
---
## 2026-02-17: Pretraining Data: FineWeb (negative)
Tried vanilla fineweb instead of fineweb-edu dataset. Significantly, shockingly worse results:
- d26 (GPT-2): CORE 0.2602 → 0.2241
This is the fifth failed attempt to beat pure FineWeb-EDU on CORE score.
---
## 2026-02-17: Pretraining Data Mixture Experiment (negative)
Tried [hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT](https://huggingface.co/datasets/hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT), a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested:
- d26 (GPT-2): CORE 0.2602 → 0.2549
- d18: CORE 0.199 → 0.192
This is the fourth failed attempt to beat pure FineWeb-EDU on CORE score.
---
## 2026-02-16: SFT Script Upgrades
Brought `chat_sft.py` up to parity with `base_train.py` and tuned settings based on SFT sweeps.
Tuning:
- **Optimizer warm-start** (`--load-optimizer=1`, default on): loads pretrained momentum buffers via new `load_optimizer_state()` in `checkpoint_manager.py`. LRs are reset to fresh SFT values after load. Loading the optimizer works slightly better but not by too much.
- **LR schedule**: replaced "constant 80%, linear to 0" with warmup/constant/warmdown matching `base_train.py` (`--warmup-ratio`, `--warmdown-ratio`, `--init-lr-frac`, `--final-lr-frac`). Similar to pretraining, warmdown ratio of 0.5 worked the best. `--init-lr-frac` changed from 1.0 slightly lower to 0.8.
- **LR tuning**: attempted to tune all the individual LRs (e.g. does SFT prefer lower LR for embeddings? etc.) but all of this produced negative results.
- **Data mixture**: MMLU epochs 1→3, GSM8K epochs 2→4 (confirmed best from sweeps). Epoch counts now configurable via `--mmlu-epochs` / `--gsm8k-epochs`. Might remove these in the future though.
Quality of life, footguns, minor fixes:
- **Hyperparameter inheritance**: SFT now inherits batch sizes and LRs from the pretrained checkpoint metadata by default (CLI overrides still work). Also saved `total_batch_size` to `base_train.py` checkpoint metadata.
- **GC management**: disabled Python GC after step 1 to avoid ~500ms pauses (manual collect every 5000 steps), same as base pretraining.
- **ChatCORE eval**: periodic eval during SFT (`--chatcore-every=200`) across all 6 tasks, logged to wandb.
- **MFU**: uses `get_peak_flops()` for actual GPU instead of hardcoded H100 value.
- Removed `--dry-run` and `--dtype` flags. All ranks now participate in checkpoint save.
---
## 2026-02-05: Auto Batch Size Scaling
### Background
@ -660,7 +749,7 @@ See the branch `fp8_attempt_fail` for:
### Open Questions
- Why does the custom op approach use more memory than vanilla BF16?
- Why is the bump in tok_per_sec so low? We should see ~1.6X speedup in both the forward pass and also (twice) in backward pass for the gradients. Granted, Ahmdal's law is part of the solution because our vocab_size is only 32K so the final layer isn't a huge part of the profile but the expected speedup is still not fully realized.
- Why is the bump in tok_per_sec so low? We should see ~1.6X speedup in both the forward pass and also (twice) in backward pass for the gradients. Granted, Amdahl's law is part of the solution because our vocab_size is only 32K so the final layer isn't a huge part of the profile but the expected speedup is still not fully realized.
**Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao.
@ -824,7 +913,7 @@ Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon i
- Now defaults to ON for Muon via the `weight_decay` param. AdamW still has no weight decay and is hardcoded to 0 weight decay, might try to re-tune this later.
**4. Weight decay schedule**
- Added a linear schedule to weight decay that is default on from 1.0 to 0.0 (i.e. start with max weight decay in the beginning of training, them ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule but it is imlpemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule).
- Added a linear schedule to weight decay that is default on from 1.0 to 0.0 (i.e. start with max weight decay in the beginning of training, then ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule but it is implemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule).
### Weight Decay Scaling Experiments
@ -868,6 +957,6 @@ Muon was changed to use Polar Express, added NorMuon variance reduction, and cau
**Bug Found:** Original implementation clipped local gradients before sync. Since this codebase doesn't use DDP (gradient sync is in the optimizers), each rank was clipping based on its own local norm. Fixed on the branch with proper distributed all-reduce.
**Observartion:** modded-nanogpt does not appear to clip either right now.
**Observation:** modded-nanogpt does not appear to clip either right now.
**Summary:** Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This improves a bit of MFU because we don't have to calculate and sync grad norms.

View File

@ -186,6 +186,9 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
if step is None:
step = find_last_step(checkpoint_dir)
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
if not os.path.exists(optimizer_path):
log0(f"Optimizer checkpoint not found: {optimizer_path}")
return None
log0(f"Loading optimizer state from {optimizer_path}")
optimizer_data = torch.load(optimizer_path, map_location=device)
return optimizer_data

View File

@ -170,7 +170,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
# Precision
if device_type == "cuda":
torch.backends.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

View File

@ -123,19 +123,16 @@ def _to_col_major(x):
class _Float8Matmul(torch.autograd.Function):
"""Custom autograd for the three FP8 GEMMs of a Linear layer.
The forward saves input and weight in their original precision for the
backward pass. Each GEMM independently re-quantizes its operands to FP8.
(We don't reuse the forward's FP8 tensors in backward the backward might
want different precision, and saving FP8 would lose information.)
The forward quantizes input and weight to FP8 and saves
the quantized tensors + scales for backward.
"""
@staticmethod
def forward(ctx, input_2d, weight):
ctx.save_for_backward(input_2d, weight)
# Quantize both operands to e4m3 (higher precision format)
input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)
# output = input @ weight.T
# input_fp8 is [B, K] contiguous = row-major (good for first arg)
@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
input_2d, weight = ctx.saved_tensors
in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors
# === GEMM 1: grad_input = grad_output @ weight ===
# Shapes: [B, N] @ [N, K] -> [B, K]
# Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
# go_fp8 is [B, N] contiguous = row-major, good for first arg
# w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
w_col = _to_col_major(w_fp8)
@ -177,17 +173,15 @@ class _Float8Matmul(torch.autograd.Function):
# === GEMM 2: grad_weight = grad_output.T @ input ===
# Shapes: [N, B] @ [B, K] -> [N, K]
go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# Transposing gives column-major, but first arg needs row-major,
# so we must call .contiguous() to physically rearrange the memory.
go_T = go_fp8_2.t().contiguous() # [N, B] row-major
go_T = go_fp8.t().contiguous() # [N, B] row-major
in_col = _to_col_major(in_fp8) # [B, K] column-major
grad_weight = torch._scaled_mm(
go_T,
in_col,
scale_a=go_inv_2,
scale_a=go_inv,
scale_b=in_inv,
out_dtype=grad_output.dtype,
use_fast_accum=False,

View File

@ -69,7 +69,7 @@ python -m scripts.tok_eval
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
# d24 model (slightly overtrained is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12)
# d26 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8.25)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.25 --device-batch-size=16 --fp8 --run=$WANDB_RUN
# evaluate the model: CORE metric, BPB on train/val, and draw samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16

View File

@ -170,20 +170,22 @@ if args.fp8:
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
import torch.nn as nn
# Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)
# Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
if not isinstance(mod, nn.Linear):
return False
# FP8 requires both in_features and out_features divisible by 16
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False
if min(mod.in_features, mod.out_features) < 128:
return False
return True
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)")
num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
num_skipped = num_linear - num_fp8
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager

View File

@ -43,7 +43,7 @@ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (e
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
# Training horizon
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes (default: inherit from pretrained checkpoint)
@ -64,6 +64,9 @@ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number o
parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
# Data mixture
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@ -132,12 +135,21 @@ token_bytes = get_token_bytes(device=device)
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the
# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and
# restore our fresh SFT LRs after loading.
base_dir = get_base_dir()
if args.load_optimizer:
optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
optimizer.load_state_dict(optimizer_data)
del optimizer_data
print0("Loaded optimizer state from pretrained checkpoint")
if optimizer_data is not None:
base_lrs = [group["lr"] for group in optimizer.param_groups]
optimizer.load_state_dict(optimizer_data)
del optimizer_data
for group, base_lr in zip(optimizer.param_groups, base_lrs):
group["lr"] = base_lr
print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)")
else:
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
# Override the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups:
@ -146,16 +158,17 @@ for group in optimizer.param_groups:
# SFT data mixture and DataLoader
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
train_tasks = [
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
GSM8K(subset="main", split="train"), # 2 epochs of GSM8K
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these
*[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch
*[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows
]
train_dataset = TaskMixture(train_tasks)
print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios

View File

@ -31,7 +31,7 @@ class MockModel:
def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
self.vocab_size = vocab_size
self.config = MockConfig()
self._device = "cpu"
self._device = torch.device("cpu")
def get_device(self):
return self._device