diff --git a/dev/LOG.md b/dev/LOG.md index c7d8b80..5f6e1d7 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,176 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-01-13: Varlen Attention (Negative Result) + +Attempted to prevent attention from "leaking" across document boundaries using Flash Attention's `flash_attn_varlen_func`, similar to modded-nanogpt's approach. + +### Background + +With the BOS-aligned dataloader, multiple documents are packed into each row. Standard attention allows tokens to attend across document boundaries within a row. The hypothesis was that preventing this "leakage" via varlen attention might improve training. + +### Approach: Compute cu_seqlens from inputs + +- Find BOS positions: `(inputs.view(-1) == bos_token_id).nonzero()` +- Gotcha 1: Variable-length `cu_seqlens` caused torch.compile recompilation (25s/iter!) - fixed by padding to fixed size +- Gotcha 2: `nonzero()` inside compiled model hit recompile limit - fixed by moving computation outside compiled region + +### Final Results (d16) + +| Metric | Baseline | Varlen | +|--------|----------|--------| +| val_bpb | 0.85427 | 0.85407 | +| MFU | ~same | ~same | +| tok/sec | ~same | ~same | + +Essentially identical. The 0.0002 bpb improvement is almost noise. + +### Conclusion + +Not worth the code complexity. The "leakage" across document boundaries within a row is not harmful - the model handles it fine. The BOS-aligned dataloader already provides the key benefit (every row starts with proper context). Not merging to master. + +--- + +## 2026-01-13: BOS-Aligned Dataloader with Bin Packing + +Redesigned the pretraining and midtraining dataloader to ensure every sequence starts with a BOS token, and explored bin-packing algorithms to minimize wasted tokens. + +### Problem Statement + +The original dataloader streams tokens into a flat buffer and reshapes into batches. This means some rows start mid-document (no BOS), which could confuse the model during training. We want every row to start with BOS and contain well-formed documents. + +### Approach 1: Greedy-Crop BOS (Simple) + +Each row is built independently: +- Start with a document (which has BOS prepended) +- Pack more documents until row is full +- If a document doesn't fit, **crop it** to fill remaining space (discard the rest) +- 100% utilization (no padding), but wastes cropped tokens + +### Waste Analysis + +Measured token waste empirically on real data (T=2048): +- **39.4% of tokens are cropped** (discarded when docs don't fit) +- **22.9% is the theoretical minimum** (tokens in docs longer than T+1 that can never fit) +- The extra ~16.5% comes from "unlucky" cropping when a long doc starts near the end of a row + +### Bin Packing Algorithms Explored + +| Algorithm | Util% | Crop% | Pad% | Notes | +|-----------|-------|-------|------|-------| +| Greedy-Crop (baseline) | 100% | 39.4% | 0% | Simple, no wasted compute | +| Greedy-Pad | 78% | 23.0% | 22% | Pads instead of crops - wastes compute | +| First-Fit Decreasing (FFD) | 99.7% | 23.0% | 0.3% | Near-optimal packing, minimal padding | +| **BestFit-Crop** | 100% | 34.6% | 0% | Smart cropping, no padding | + +### BestFit-Crop Algorithm + +A middle ground that maintains 100% utilization while reducing cropping: + +1. Buffer N documents +2. For each row, greedily pick the **largest doc that fits entirely** +3. Repeat until nothing fits +4. 
When nothing fits, crop a doc to fill remaining space exactly + +This avoids "unlucky" crops by searching the buffer for better-fitting documents. + +**Results (T=2048):** +- Crop waste reduced from 39.4% → 34.6% (~12% relative improvement) +- Still achieves 100% utilization (no padding, every token trains) +- Slightly more rows than baseline (uses more documents per batch) + +### Decision: Keep Two Implementations + +1. **Original (kept as-is)**: Very simple and efficient, with 100% token utilization in the batch (no padding with ignore tokens), but it creates slightly more confusing token streams for the LLM because documents during training can start abruptly from the middle with no context. Note that this never happens at test time, where BOS is always present. + +2. **`_bos_bestfit` (BestFit-Crop, new default)**: Slightly more complex, but it still keeps 100% token utilization in the batch (no padding), at the cost of discarding the cropped-off remainder of documents that don't fit. In practice, about 35% of tokens are discarded with this approach. This is OK because, for most models we care about, we have plenty of data without having to go to multiple epochs. One more subtle effect is that it does skew the data distribution a tiny bit because, reliably and necessarily, tokens at the tails of long documents will be discarded. However, this doesn't seem to impact actual downstream performance. + +### Midtraining + +The midtraining dataloader was also updated. Because conversations are on average a lot shorter than pretraining documents, only about 3.3% of tokens get cropped. + +### NOTE: loss scale + +Do note that switching to the BOS dataloader changes the validation loss and makes all previous experiments non-comparable in absolute loss value, because we have a lot fewer "confusing" tokens in the train/val batches. All tokens can look back and find the BOS token and have the full context of that document to make predictions. Therefore, the loss appears lower, but this is "fake" to some extent; the expectation is that the vast majority of relative comparisons done so far would come out the same before and after this change. + +--- + +## 2026-01-13: Number Token Split Pattern + +Validated the `\p{N}{1,2}` pattern in `SPLIT_PATTERN` (tokenizer.py line 30), which I had only guessed earlier and left a TODO to validate. GPT-4 uses `\p{N}{1,3}` to group number sequences of up to 3 digits into tokens, but we suspected smaller vocab sizes benefit from grouping fewer digits per token. + +**Results (d12, vocab=32K):** +| Pattern | val_bpb | +|---------|---------| +| `\p{N}{1,1}` | 0.969 | +| `\p{N}{1,2}` | **0.965** | +| `\p{N}{1,3}` | 0.972 | + +**Conclusion:** `{1,2}` is optimal for vocab size 32K. Grouping 3 digits wastes tokens on rare 3-digit combinations; grouping 1 digit is too fine-grained and bloats token sequences. Keeping `{1,2}` as default. + +--- + +## 2026-01-13: FP8 Training for lm_head + +Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16. + +### Implementation Approaches Tried + +**1. 
Dynamic Scaling (failed)** +- Compute `x.abs().max()` and `w.abs().max()` each forward to determine scales +- Problem: `.item()` calls cause graph breaks with torch.compile +- Tried `@torch._dynamo.allow_in_graph` pattern (like torchao.float8) - worked but no speedup +- Tried `torch.library.custom_op` with float scales - caused NaN gradients after first optimizer step +- Root cause: interaction between custom ops, dynamic scale computation, and torch.compile is fragile + +**2. Static Scaling (partial success)** +- Pre-set scales at init time like modded-nanogpt: `x_scale=10/448, w_scale=0.1/448` +- `grad_scale` computed dynamically from batch size (safe since it's just `1/(B*T)/57344` due to the gradient expression of cross entropy). modded-nanogpt probably has a bug here: they set `grad_scale = 0.75/448`, but grads are in E5M2, so this should probably be `1/57344` (1 being the amax of any individual element of the cross entropy loss, with no normalization by B,T because they use sum reduction, not mean reduction). +- Uses `torch.library.custom_op` with `@torch.compile` on inner kernels +- This works correctly - no NaNs, proper gradients + +### Results (d12) + +| Metric | BF16 Baseline | FP8 lm_head | +|--------|---------------|-------------| +| GPU Memory | 34 GB | 36 GB | +| tok/sec | baseline | ~1% faster | + +### The Memory Mystery + +FP8 *should* save memory since we store `x_f8` (1 byte) instead of `x` (2 bytes) for backward. But we see a 2GB *increase*. Suspected causes: +- `torch.compile` on inner kernels creating extra buffers/specializations +- `torch._scaled_mm` internal workspace allocations +- Custom op registration machinery overhead + +Tried saving the original weight `w` (just a reference to the parameter) instead of `w_f8` for backward, then re-quantizing on the spot during backward - didn't help, still saw the bump. + +### Microbenchmark vs Reality + +Raw microbenchmark showed promise: +- BF16 matmul: 16.95 ms +- FP8 matmul (static scales): 10.31 ms (1.64x faster) +- FP8 with dynamic scaling: 12.25 ms (1.38x faster) + +But in full training, the ~1% tok/sec improvement doesn't justify the 2GB memory increase, the added code complexity, and the need to tune scale factors for both x and w. + +### Code Artifacts + +See the branch `fp8_attempt_fail` for: + +- `nanochat/fp8_static.py` - Static scaling implementation (working) +- `nanochat/fp8_dynamic.py` - Dynamic scaling implementation (torchao-style, working but slow) +- `gpt.py` imports `fp8_static.LinearFP8` and simply swaps it in for `lm_head`. + +### Open Questions + +- Why does the custom op approach use more memory than vanilla BF16? +- Why is the bump in tok/sec so low? We should see a ~1.6X speedup in the forward pass and also (twice) in the backward pass for the gradients. Granted, Amdahl's law is part of the explanation because our vocab_size is only 32K, so the final layer isn't a huge part of the profile, but the expected speedup is still not fully realized. + +**Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao. + +--- + ## 2026-01-12: Multi-Token Prediction (MTP) Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss. 
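As a concrete illustration of that loss structure, here is a minimal sketch of an MTP-style weighted loss in PyTorch. The per-horizon logits list, the `mtp_loss` helper, and the weight values are illustrative assumptions, not the actual nanochat/modded-nanogpt implementation:

```python
import torch.nn.functional as F

def mtp_loss(logits_per_horizon, tokens, weights=(1.0, 0.5)):
    """Weighted sum of cross-entropy losses over prediction horizons k = 1..n (sketch)."""
    # logits_per_horizon: list of (B, T, V) tensors, one per horizon
    # tokens: (B, T + n) token ids, so the k-step-ahead target exists for every position
    total, weight_sum = 0.0, 0.0
    for k, (logits, w) in enumerate(zip(logits_per_horizon, weights), start=1):
        B, T, V = logits.shape
        targets = tokens[:, k:k + T]  # the token k steps ahead of each of the T positions
        loss_k = F.cross_entropy(logits.reshape(B * T, V), targets.reshape(B * T))
        total, weight_sum = total + w * loss_k, weight_sum + w
    return total / weight_sum
```

The key knobs are n (how many tokens ahead to predict) and the per-horizon weights.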
diff --git a/dev/runcpu.sh b/dev/runcpu.sh index a58bfbc..c3ad290 100755 --- a/dev/runcpu.sh +++ b/dev/runcpu.sh @@ -25,7 +25,7 @@ python -m nanochat.report reset # train tokenizer on ~1B characters python -m nanochat.dataset -n 6 -python -m scripts.tok_train --max_chars=1000000000 +python -m scripts.tok_train --max-chars=1000000000 python -m scripts.tok_eval # train a very small 4 layer model on the CPU @@ -33,37 +33,37 @@ python -m scripts.tok_eval # we only run 50 steps of optimization (bump this to get better results) python -m scripts.base_train \ --depth=4 \ - --max_seq_len=1024 \ - --device_batch_size=1 \ - --total_batch_size=1024 \ - --eval_every=50 \ - --eval_tokens=4096 \ - --core_metric_every=50 \ - --core_metric_max_per_task=12 \ - --sample_every=50 \ - --num_iterations=50 -python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096 + --max-seq-len=1024 \ + --device-batch-size=1 \ + --total-batch-size=1024 \ + --eval-every=50 \ + --eval-tokens=4096 \ + --core-metric-every=50 \ + --core-metric-max-per-task=12 \ + --sample-every=50 \ + --num-iterations=50 +python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096 python -m scripts.base_eval --max-per-task=16 # midtraining python -m scripts.mid_train \ - --max_seq_len=1024 \ - --device_batch_size=1 \ - --eval_every=50 \ - --eval_tokens=4096 \ - --total_batch_size=1024 \ - --num_iterations=100 + --max-seq-len=1024 \ + --device-batch-size=1 \ + --eval-every=50 \ + --eval-tokens=4096 \ + --total-batch-size=1024 \ + --num-iterations=100 # eval results will be terrible, this is just to execute the code paths. # note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20 # SFT python -m scripts.chat_sft \ - --device_batch_size=1 \ - --target_examples_per_step=4 \ - --num_iterations=100 \ - --eval_steps=4 \ - --eval_metrics_max_problems=16 + --device-batch-size=1 \ + --target-examples-per-step=4 \ + --num-iterations=100 \ + --eval-steps=4 \ + --eval-metrics-max-problems=16 # Chat CLI # python -m scripts.chat_cli -p "Why is the sky blue?" diff --git a/miniseries.sh b/miniseries.sh index 0a6947e..9a4512b 100644 --- a/miniseries.sh +++ b/miniseries.sh @@ -17,9 +17,10 @@ if [ -z "$SKIP_SETUP" ]; then uv sync --extra gpu source .venv/bin/activate - # Tokenizer - python -m nanochat.dataset -n 240 - python -m scripts.tok_train --max_chars=2000000000 --vocab_size=32768 + # Tokenizer, download 1000 shards for pretraining + # (probably this can be reduced but it's tricky to determine the exact right number, TODO). 
+ python -m nanochat.dataset -n 1000 + python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768 else source .venv/bin/activate fi @@ -57,16 +58,16 @@ for d in "${DEPTHS[@]}"; do START_TIME=$(date +%s) # Train the model with natural horizon (target_param_data_ratio default) - # No --target_flops, let it use the default ratio from base_train + # No --target-flops, let it use the default ratio from base_train torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ --depth=$d \ - --target_param_data_ratio=8 \ + --target-param-data-ratio=8 \ --run="${WANDB_RUN}_d${d}" \ - --model_tag="${TAG}" \ - --core_metric_every=999999 \ - --core_metric_max_per_task=-1 \ - --sample_every=-1 \ - --save_every=-1 \ + --model-tag="${TAG}" \ + --core-metric-every=999999 \ + --core-metric-max-per-task=-1 \ + --sample-every=-1 \ + --save-every=-1 \ 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" END_TIME=$(date +%s) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index cca6294..d1e0a07 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -25,6 +25,7 @@ def _patch_missing_config_keys(model_config_kwargs): # Old models were trained with full context (no sliding window) if "window_pattern" not in model_config_kwargs: model_config_kwargs["window_pattern"] = "L" + log0(f"Patching missing window_pattern in model config to 'L'") def _patch_missing_keys(model_data, model_config): """Add default values for new parameters that may be missing in old checkpoints.""" @@ -32,9 +33,11 @@ def _patch_missing_keys(model_data, model_config): # resid_lambdas defaults to 1.0 (identity scaling) if "resid_lambdas" not in model_data: model_data["resid_lambdas"] = torch.ones(n_layer) + log0(f"Patching missing resid_lambdas in model data to 1.0") # x0_lambdas defaults to 0.0 (disabled) if "x0_lambdas" not in model_data: model_data["x0_lambdas"] = torch.zeros(n_layer) + log0(f"Patching missing x0_lambdas in model data to 0.0") def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): if rank == 0: @@ -108,7 +111,7 @@ def build_model(checkpoint_dir, step, device, phase): # Load the Tokenizer tokenizer = get_tokenizer() # Sanity check: compatibility between model and tokenizer - assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"] + assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}" return model, tokenizer, meta_data diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 20dd88f..562d517 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -1,4 +1,25 @@ -from collections import deque +""" +Distributed dataloaders for pretraining. + +Two implementations are provided: + +1. Original (tokenizing_distributed_data_loader): + - Streams tokens into a flat buffer, reshapes to (B, T) + - Rows may start mid-document (no guaranteed BOS at position 0) + - 100% token utilization, simple and efficient + +2. 
BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit): + - Every row starts with BOS token + - Documents packed using best-fit algorithm to minimize cropping + - When no document fits remaining space, crops a document to fill exactly + - 100% utilization (no padding), ~35% tokens cropped at T=2048 + +The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that +there are fewer "confusing" tokens in the train/val batches as every token can +now attend back to the BOS token and sees the full context of the document. +(2) is the new default if you have enough data. +Fallback to (1) if you have very limited data AND long documents. +""" import torch import pyarrow.parquet as pq @@ -6,86 +27,172 @@ import pyarrow.parquet as pq from nanochat.common import get_dist_info from nanochat.dataset import list_parquet_files +def _document_batches(split, resume_state_dict, tokenizer_batch_size): + """ + Infinite iterator over document batches (list of text strings) from parquet files. + + Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch)) + where text_batch is a list of document strings, indices track position for resumption, + and epoch counts how many times we've cycled through the dataset (starts at 1). + """ + ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() + + parquet_paths = list_parquet_files() + assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?" + parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] + + resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0 + resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None + resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1 + first_pass = True + pq_idx = resume_pq_idx + epoch = resume_epoch + + while True: # iterate infinitely (multi-epoch) + pq_idx = resume_pq_idx if first_pass else 0 + while pq_idx < len(parquet_paths): + filepath = parquet_paths[pq_idx] + pf = pq.ParquetFile(filepath) + # Start from resume point if resuming on same file, otherwise from DDP rank + if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx): + base_idx = resume_rg_idx // ddp_world_size + base_idx += 1 # advance by 1 so we don't repeat data after resuming + rg_idx = base_idx * ddp_world_size + ddp_rank + if rg_idx >= pf.num_row_groups: + pq_idx += 1 + continue + resume_rg_idx = None # only do this once + else: + rg_idx = ddp_rank + while rg_idx < pf.num_row_groups: + rg = pf.read_row_group(rg_idx) + batch = rg.column('text').to_pylist() + for i in range(0, len(batch), tokenizer_batch_size): + yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch) + rg_idx += ddp_world_size + pq_idx += 1 + first_pass = False + epoch += 1 + + def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None): """ Stream pretraining text from parquet files, tokenize, yield training batches. - This implementation became a bit more complex because we wish to support approximate resume training. - Instead of turning this into a Class, we opt to return the state_dict with every batch, - and then the caller can pass in a state_dict to resume training from a desired point. - Note that this resumption is atm only *approximate* for simplicity. - We won't repeat the same documents but we might skip a few. 
- The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume. + This is the original dataloader that streams tokens into a flat buffer and reshapes. + Rows may start mid-document (no guaranteed BOS at position 0). - Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm. + Supports approximate resume via state_dict. """ assert split in ["train", "val"], "split must be 'train' or 'val'" - # infinite iterator over document batches (list of text strings) - ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() - def document_batches(): - parquet_paths = list_parquet_files() - assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?" - parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] - resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0 - resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None - first_pass = True - pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0) - while True: # iterate infinitely (multi-epoch) - pq_idx = resume_pq_idx if first_pass else 0 - while pq_idx < len(parquet_paths): # iterate over all parquet files - filepath = parquet_paths[pq_idx] - pf = pq.ParquetFile(filepath) - # Start from resume point if resuming on same file, otherwise from DDP rank - # I know this state resumption is a little bit tricky and a little bit hacky... sigh. - if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx): - base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size - base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming - rg_idx = base_idx * ddp_world_size + ddp_rank - if rg_idx >= pf.num_row_groups: - pq_idx += 1 - continue - resume_rg_idx = None # set to None as we only want to do this a single time - else: - rg_idx = ddp_rank - while rg_idx < pf.num_row_groups: - rg = pf.read_row_group(rg_idx) - batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows - # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows - for i in range(0, len(batch), tokenizer_batch_size): - yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx) - rg_idx += ddp_world_size # advance to the next row group (in DDP) - pq_idx += 1 # advance to the next parquet file - first_pass = False - batches = document_batches() - - # Now emit batches of tokens. - needed_tokens = B * T + 1 # +1 is because we also need the target at the last token + batches = _document_batches(split, resume_state_dict, tokenizer_batch_size) + needed_tokens = B * T + 1 # +1 for target at last position bos_token = tokenizer.get_bos_token_id() - # scratch buffer holds the tokens for one iteration - token_buffer = deque() # we stream tokens on the right and pop from the left + token_buffer = [] + pq_idx, rg_idx, epoch = 0, 0, 1 + while True: - # Accumulate enough tokens for one iteration before yielding. 
+ + # Accumulate enough tokens while len(token_buffer) < needed_tokens: - doc_batch, (pq_idx, rg_idx) = next(batches) + doc_batch, (pq_idx, rg_idx, epoch) = next(batches) token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) for tokens in token_lists: token_buffer.extend(tokens) - # Move tokens from the deque into the scratch buffer - tokens = [token_buffer.popleft() for _ in range(needed_tokens)] - # CUDA supports memory pinning for asynchronous transfers between CPU and GPU - use_cuda_optimizations = device == "cuda" - scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64 - # Create the inputs/targets as 1D tensors - inputs_cpu = scratch[:-1] - targets_cpu = scratch[1:] - # Reshape to 2D and move to GPU async - inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) - targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) - state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training - yield inputs, targets, state_dict + tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token) + token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over + + # Package tokens into inputs and targets, yield + use_cuda = device == "cuda" + scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda) + inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda) + targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda) + yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch} + def tokenizing_distributed_data_loader(*args, **kwargs): - # helper function that only emits the inputs/targets and not the state_dict + """Helper that omits state_dict from yields.""" for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs): yield inputs, targets + + +def tokenizing_distributed_data_loader_with_state_bos_bestfit( + tokenizer, B, T, split, + tokenizer_threads=4, tokenizer_batch_size=128, + device="cuda", resume_state_dict=None, + buffer_size=1000 +): + """ + BOS-aligned dataloader with Best-Fit Cropping. + + Reduces token waste compared to simple greedy cropping by searching a buffer + for documents that fit well, while maintaining 100% utilization (no padding). + + Algorithm for each row: + 1. From buffered docs, pick the LARGEST doc that fits entirely + 2. Repeat until no doc fits + 3. 
When nothing fits, crop a doc to fill remaining space exactly + + Key properties: + - Every row starts with BOS + - 100% utilization (no padding, every token is trained on) + - Approximately 35% of all tokens are discarded due to cropping + """ + assert split in ["train", "val"], "split must be 'train' or 'val'" + + row_capacity = T + 1 + batches = _document_batches(split, resume_state_dict, tokenizer_batch_size) + bos_token = tokenizer.get_bos_token_id() + doc_buffer = [] + pq_idx, rg_idx, epoch = 0, 0, 1 + + def refill_buffer(): + nonlocal pq_idx, rg_idx, epoch + doc_batch, (pq_idx, rg_idx, epoch) = next(batches) + token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) + for tokens in token_lists: + doc_buffer.append(tokens) + + while True: + rows = [] + for _ in range(B): + row = [] + while len(row) < row_capacity: + # Ensure buffer has documents + while len(doc_buffer) < buffer_size: + refill_buffer() + + remaining = row_capacity - len(row) + + # Find largest doc that fits entirely + best_idx = -1 + best_len = 0 + for i, doc in enumerate(doc_buffer): + doc_len = len(doc) + if doc_len <= remaining and doc_len > best_len: + best_idx = i + best_len = doc_len + + if best_idx >= 0: + doc = doc_buffer.pop(best_idx) + row.extend(doc) + else: + # No doc fits - crop first doc to fill remaining + doc = doc_buffer.pop(0) + row.extend(doc[:remaining]) + + rows.append(row[:row_capacity]) + + use_cuda = device == "cuda" + batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda) + inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda) + targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda) + + yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch} + + +def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs): + """Helper that omits state_dict from yields.""" + for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs): + yield inputs, targets diff --git a/nanochat/muon.py b/nanochat/muon.py index 7ae5ffd..cfd2443 100644 --- a/nanochat/muon.py +++ b/nanochat/muon.py @@ -1,7 +1,27 @@ """ -Muon optimizer adapted (simplified) from modded-nanogpt. +Muon optimizer adapted and simplified from modded-nanogpt. https://github.com/KellerJordan/modded-nanogpt + +Background: +Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a +quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose +of minimizing steps, it turns out to be empirically effective to keep increasing the slope at +zero even beyond the point where the iteration no longer converges all the way to one everywhere +on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T +where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model +performance at all relative to UV^T, where USV^T = G is the SVD. + +Here, an alternative to Newton-Schulz iteration with potentially better convergence properties: +Polar Express Sign Method for orthogonalization. +https://arxiv.org/pdf/2505.16932 +by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + +Some of the changes in nanochat implementation: +- Uses a simpler, more general approach to parameter grouping and stacking +- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step +- Makes no assumptions about model architecture (e.g. 
that attention weights are fused into QKVO format) """ + import torch from torch import Tensor import torch.distributed as dist @@ -16,97 +36,61 @@ polar_express_coeffs = [ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), ] - -@torch.compile -def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor: +@torch.compile(dynamic=False, fullgraph=True) +def muon_step_fused( + stacked_grads: Tensor, + stacked_params: Tensor, + momentum_buffer: Tensor, + second_momentum_buffer: Tensor, + momentum_t: Tensor, + lr_t: Tensor, + wd_t: Tensor, + beta2_t: Tensor, + ns_steps: int, + red_dim: int, +) -> None: """ - Polar Express Sign Method for orthogonalization. - https://arxiv.org/pdf/2505.16932 - by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. - - Alternative to Newton-Schulz iteration with potentially better convergence properties. + Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update + All in one compiled graph to eliminate Python overhead between ops. + Some of the constants are 0-D CPU tensors to avoid recompilation when values change. """ - assert G.ndim >= 2 - X = G.bfloat16() - if G.size(-2) > G.size(-1): + + # Nesterov momentum + momentum = momentum_t.to(stacked_grads.dtype) + momentum_buffer.lerp_(stacked_grads, 1 - momentum) + g = stacked_grads.lerp_(momentum_buffer, momentum) + + # Polar express + X = g.bfloat16() + if g.size(-2) > g.size(-1): X = X.mT - - # Ensure spectral norm is at most 1 (with 2% safety factor) X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6) - - # Perform the iterations (cap at available coefficients) - for a, b, c in polar_express_coeffs[:min(steps, len(polar_express_coeffs))]: + for a, b, c in polar_express_coeffs[:ns_steps]: A = X @ X.mT B = b * A + c * (A @ A) X = a * X + B @ X - - if G.size(-2) > G.size(-1): + if g.size(-2) > g.size(-1): X = X.mT - return X + g = X - -@torch.compile -def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - # Perform the NS iterations - for _ in range(steps): - A = X @ X.mT - B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng - X = a * X + B @ X - - if G.size(-2) > G.size(-1): - X = X.mT - return X - - -@torch.compile -def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor: - """ - NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator. 
- https://arxiv.org/pdf/2510.05491 - - Normalizes updates based on a running estimate of per-row (or per-column) variance. - The reduction dimension is determined by the shape of second_momentum_buffer. - """ - # Determine reduction dimension from buffer shape - red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2 - - # Compute per-row/col mean of squared values - v_mean = v.float().square().mean(dim=red_dim, keepdim=True) - red_dim_size = v.size(red_dim) - - # Compute current norm + # Variance reduction + beta2 = beta2_t.to(g.dtype) + v_mean = g.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = g.size(red_dim) v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size v_norm = v_norm_sq.sqrt() - - # Update second momentum buffer (EMA of variance) second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) - - # Compute scaling factor from second momentum step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt() scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt() - - # Final scale preserves overall norm while adjusting per-row/col final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10)) - return v.mul(final_scale.to(v.dtype)) + g = g * final_scale.to(g.dtype) + # Cautious weight decay + parameter update + lr = lr_t.to(g.dtype) + wd = wd_t.to(g.dtype) + mask = (g * stacked_params) >= 0 + stacked_params.sub_(lr * g + lr * wd * stacked_params * mask) class Muon(torch.optim.Optimizer): """ @@ -127,94 +111,112 @@ class Muon(torch.optim.Optimizer): Arguments: lr: The learning rate used by the internal SGD. momentum: The momentum used by the internal SGD. - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iteration steps to use. beta2: The decay rate for the second moment (variance) estimate. Set to None to disable. weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree. """ - def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0): - defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay) - params: list[Tensor] = [*params] + def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0): + defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay) + assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only" + params = list(params) # ensure we have a list, not an e.g. 
(exhaustible) iterator + # Group by shape so we can stack tensors + shapes = sorted({p.shape for p in params}) param_groups = [] - for size in {p.numel() for p in params}: - group = dict(params=[p for p in params if p.numel() == size]) - param_groups.append(group) + for shape in shapes: + group_params = [p for p in params if p.shape == shape] + param_groups.append(dict(params=group_params)) super().__init__(param_groups, defaults) + # 0-D CPU tensors to avoid torch.compile recompilation when values change + self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") @torch.no_grad() def step(self): for group in self.param_groups: params: list[Tensor] = group["params"] - for p in params: - g = p.grad - assert g is not None - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf: Tensor = state["momentum_buffer"] - buf.lerp_(g, 1 - group["momentum"]) - g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf - g = zeropower_via_polar_express(g, steps=group["ns_steps"]) - # Variance reduction (NorMuon-style) - if group["beta2"] is not None: - if "second_momentum_buffer" not in state: - # Buffer shape determines reduction dim: reduce along larger dimension - if p.size(-2) >= p.size(-1): - state["second_momentum_buffer"] = torch.zeros_like(g[..., :1]) - else: - state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :]) - g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"]) - # Parameter update with cautious weight decay - effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5 - wd = group["weight_decay"] - if wd != 0: - mask = (g * p) >= 0 - p.sub_(effective_lr * g + effective_lr * wd * p * mask) + if not params: + continue + + # Get or create group-level buffers (stored in first param's state for convenience) + state = self.state[params[0]] + num_params = len(params) # e.g.: 12 (for a d12 model) + # e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections + shape, device, dtype = params[0].shape, params[0].device, params[0].dtype + + # Momentum for every individual parameter + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device) + momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072) + + # Second momentum buffer is factored, either per-row or per-column + if "second_momentum_buffer" not in state: + if shape[-2] >= shape[-1]: + state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device) else: - p.sub_(effective_lr * g) + state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device) + second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072) + red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2 + + # Stack grads and params + stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072) + stacked_params = torch.stack(params) # (12, 768, 3072) + + # Fill all the 0-D tensors with current values + self._momentum_t.fill_(group["momentum"]) + self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) + self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5) + self._wd_t.fill_(group["weight_decay"]) + + # Single 
fused kernel: momentum -> polar_express -> variance_reduction -> update + muon_step_fused( + stacked_grads, + stacked_params, + momentum_buffer, + second_momentum_buffer, + self._momentum_t, + self._lr_t, + self._wd_t, + self._beta2_t, + group["ns_steps"], + red_dim, + ) + + # Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072) + torch._foreach_copy_(params, list(stacked_params.unbind(0))) class DistMuon(torch.optim.Optimizer): """ - Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express, - finally apply aspect-ratio scaled step. Performs its own distributed synchronization: - - reduce_scatter(AVG) for gradient averaging - - all_gather to replicate updated weights - - Notes: - * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D - params like embeddings or scalars. - * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen - by block-cyclic assignment below). If you checkpoint optimizer state on a single rank, - consolidate states beforehand. - - Args: - params: iterable of Tensors - lr: learning rate - momentum: momentum coefficient in [0,1) - nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf - ns_steps: number of Newton-Schulz iterations for the orthogonalization - beta2: decay rate for second moment (variance) estimate. Set to None to disable. - weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree. + Distributed version of the Muon optimizer. """ def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, - nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0): - defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay) - params = list(params) + ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0): + defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay) assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only" + params = list(params) + world_size = dist.get_world_size() rank = dist.get_rank() # Group all parameters by their shape - shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering + shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks param_groups = [] for shape in shapes: group_params = [p for p in params if p.shape == shape] device, dtype = group_params[0].device, group_params[0].dtype assert all(p.device == device for p in group_params) assert all(p.dtype == dtype for p in group_params) + # Compute chunk size for this group (how many params each rank owns) + chunk_size = (len(group_params) + world_size - 1) // world_size if rank == 0: - print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}") - param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0]))) + print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}") + param_groups.append(dict(params=group_params, chunk_size=chunk_size)) super().__init__(param_groups, defaults) + # 0-D CPU tensors to avoid torch.compile recompilation when values change + self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + 
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") @torch.no_grad() def step(self): @@ -224,72 +226,127 @@ class DistMuon(torch.optim.Optimizer): # Ensure all grads exist assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads" - # Kick off all the reduce scatter operations to average up the gradients across all ranks - all_reduce_futures = [] + # First pass: stack grads and kick off reduce_scatter for each group + group_infos = [] for group in self.param_groups: - params = group["params"] - zero_buffer = group["zero_buffer"] - # Go through params in groups of world_size. - for base_i in range(0, len(params), world_size): - # The compute owner of each param is rank i % world_size - owner_idx = base_i + rank - # each rank stacks up its chunk of world_size params into a list - rs_input = [p.grad for p in params[base_i:base_i + world_size]] - # pad rs_input with the zero buffer to complete the group - rs_input.extend([zero_buffer] * (world_size - len(rs_input))) - # the output buffer gets strided across the group based on the rank - rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer) - # reduce scatter the gradients within this group of world_size params - work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future() - all_reduce_futures.append(work) + params: list[Tensor] = group["params"] + chunk_size = group["chunk_size"] + padded_num_params = chunk_size * world_size + shape = params[0].shape + device, dtype = params[0].device, params[0].dtype - # Now each rank computes the update and gathers - future_idx = 0 + # Stack all gradients into a single tensor (single kernel via torch.stack) + grad_stack = torch.stack([p.grad for p in params]) + stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device) + stacked_grads[:len(params)].copy_(grad_stack) + # Zero-pad if we have fewer params than padded size + if len(params) < padded_num_params: + stacked_grads[len(params):].zero_() + + # Output buffer for this rank's chunk + grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device) + + # Async reduce_scatter on the stacked tensor + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append(dict( + grad_chunk=grad_chunk, + reduce_future=reduce_future, + stacked_grads=stacked_grads, # reuse for all_gather output + )) + + # Second pass: wait for reduce, compute batched updates, kick off all_gather all_gather_futures = [] - for group in self.param_groups: - params = group["params"] - zero_buffer = group["zero_buffer"] - # Go through params in groups of world_size. 
- for base_i in range(0, len(params), world_size): - # The compute owner of each param is rank i % world_size - owner_idx = base_i + rank # calculate the index of the param that this rank owns - # Wait for the reduce scatter to complete - all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead - future_idx += 1 - # Owner computes the Muon update, result is in its param - if owner_idx < len(params): - p = params[owner_idx] - g = p.grad # now averaged across ranks - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf: Tensor = state["momentum_buffer"] - buf.lerp_(g, 1.0 - group["momentum"]) - g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf - g = zeropower_via_polar_express(g, steps=group["ns_steps"]) - # Variance reduction (NorMuon-style) - if group["beta2"] is not None: - if "second_momentum_buffer" not in state: - # Buffer shape determines reduction dim: reduce along larger dimension - if p.size(-2) >= p.size(-1): - state["second_momentum_buffer"] = torch.zeros_like(g[..., :1]) - else: - state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :]) - g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"]) - # Parameter update with cautious weight decay - effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5) - wd = group["weight_decay"] - if wd != 0: - mask = (g * p) >= 0 - p.sub_(effective_lr * g + effective_lr * wd * p * mask) - else: - p.sub_(effective_lr * g) - # Replicate updated parameters to all ranks - ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer - ag_output = params[base_i:base_i + world_size] - ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad - work = dist.all_gather(ag_output, ag_input, async_op=True).get_future() - all_gather_futures.append(work) + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() - # Wait for all work to finish - torch.futures.collect_all(all_gather_futures).wait() + params = group["params"] + chunk_size = group["chunk_size"] + shape = params[0].shape + device, dtype = params[0].device, params[0].dtype + grad_chunk = info["grad_chunk"] + + # How many params does this rank actually own? 
+ start_idx = rank * chunk_size + num_owned = min(chunk_size, max(0, len(params) - start_idx)) + + # Get or create group-level state (stored keyed by first param) + state = self.state[params[0]] + + # Momentum buffer + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device) + momentum_buffer = state["momentum_buffer"] + + # Second momentum buffer is factored, either per-row or per-column + if "second_momentum_buffer" not in state: + if shape[-2] >= shape[-1]: + state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device) + else: + state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device) + second_momentum_buffer = state["second_momentum_buffer"] + red_dim = -1 if shape[-2] >= shape[-1] else -2 + + # Build updated_params tensor for all_gather + updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device) + + if num_owned > 0: + # Stack owned params (single kernel via torch.stack) + owned_params = [params[start_idx + i] for i in range(num_owned)] + stacked_owned_params = torch.stack(owned_params) + + # Get owned slices of buffers and grads + owned_grads = grad_chunk[:num_owned] + owned_momentum = momentum_buffer[:num_owned] + owned_second_momentum = second_momentum_buffer[:num_owned] + + # Fill 0-D tensors with current values + self._momentum_t.fill_(group["momentum"]) + self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) + self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5) + self._wd_t.fill_(group["weight_decay"]) + + # Single fused kernel: momentum -> polar_express -> variance_reduction -> update + muon_step_fused( + owned_grads, + stacked_owned_params, + owned_momentum, + owned_second_momentum, + self._momentum_t, + self._lr_t, + self._wd_t, + self._beta2_t, + group["ns_steps"], + red_dim, + ) + + # Copy updated params to output buffer + updated_params[:num_owned].copy_(stacked_owned_params) + + # Zero-pad the rest (for ranks that own fewer params) + if num_owned < chunk_size: + updated_params[num_owned:].zero_() + + # Reuse stacked_grads buffer for all_gather output + stacked_params = info["stacked_grads"] + + # Async all_gather to replicate updated params to all ranks + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_params, async_op=True + ).get_future() + + all_gather_futures.append(dict( + gather_future=gather_future, + stacked_params=stacked_params, + params=params, + )) + + # Final pass: wait for all_gather and copy back to params + for info in all_gather_futures: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + params = info["params"] + # Batched copy back (single kernel instead of N individual copies) + torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0))) diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index e8ccafa..a2146c2 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -26,7 +26,7 @@ SPECIAL_TOKENS = [ # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3} # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes. -# I haven't validated that this is actually a good idea, TODO. +# I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still. 
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" # ----------------------------------------------------------------------------- diff --git a/run1000.sh b/run1000.sh index 669b279..c04583b 100644 --- a/run1000.sh +++ b/run1000.sh @@ -20,18 +20,18 @@ curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-publ # train tokenizer on ~4B characters and kick off download of the rest for pretraining python -m nanochat.dataset -n 21 -# start downloading the rest of the shards for a total of 800 (see below why 800) -python -m nanochat.dataset -n 800 & +# start downloading the rest of the shards for a total of 1200 (see below why 1200) +python -m nanochat.dataset -n 1200 & # todo: download the rest of it -python -m scripts.tok_train --max_chars=4000000000 --vocab_size=65536 +python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536 python -m scripts.tok_eval # Documenting my process for determining the hyperparameters for this run1000.sh script: # We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute # 1) I guessed the model size for this to be about depth=32 # 2) Determine the device_batch_size that fits: -# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16 -# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training, +# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16 +# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training, # I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%. # So the training script was running ok and showed: # Vocab size: 65,536 @@ -62,7 +62,9 @@ python -m scripts.tok_eval # The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings. # So ~38B tokens # ~4.8 chars/token = ~185B chars. # Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards. -# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards. +# For safety, I bumped that up to 800 shards. +# The new DataLoader wastes about 35% of tokens to cropping, so 800 / (1 - 0.35) ~= 1200 shards are needed. +# => why up above I used -n 1200 when pre-downloading dataset shards. # If we didn't have enough data, the training script would loop around and do multiple epochs over the same data, # which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd # start to overfit hard. @@ -71,13 +73,13 @@ python -m scripts.tok_eval # Number of processes/GPUs to use NPROC_PER_NODE=8 -torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target_param_data_ratio=20 --device_batch_size=8 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval # midtrain # NOTE: ensure that we use the same device_batch_size here as the base training script. 
-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid # sft diff --git a/scaling_laws.sh b/scaling_laws.sh index 102ba11..321b286 100644 --- a/scaling_laws.sh +++ b/scaling_laws.sh @@ -64,15 +64,15 @@ for flops in "${FLOPS_BUDGETS[@]}"; do # CORE eval happens once at the end (999999 ensures only final step) torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ --depth=$d \ - --target_flops=$flops \ - --target_param_data_ratio=-1 \ + --target-flops=$flops \ + --target-param-data-ratio=-1 \ --run="${WANDB_RUN}_${TAG}" \ - --model_tag="${TAG}" \ - --eval_tokens=$EVAL_TOKENS \ - --core_metric_every=999999 \ - --core_metric_max_per_task=-1 \ - --sample_every=-1 \ - --save_every=-1 \ + --model-tag="${TAG}" \ + --eval-tokens=$EVAL_TOKENS \ + --core-metric-every=999999 \ + --core-metric-max-per-task=-1 \ + --sample-every=-1 \ + --save-every=-1 \ 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" END_TIME=$(date +%s) diff --git a/scripts/base_loss.py b/scripts/base_loss.py index 094299a..6b44a30 100644 --- a/scripts/base_loss.py +++ b/scripts/base_loss.py @@ -7,14 +7,14 @@ Example run as: torchrun --standalone --nproc_per_node=8 -m scripts.base_loss To evaluate a HuggingFace model: -python -m scripts.base_loss --hf_path openai-community/gpt2 +python -m scripts.base_loss --hf-path openai-community/gpt2 """ import argparse from contextlib import nullcontext import torch from nanochat.checkpoint_manager import load_model from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type -from nanochat.dataloader import tokenizing_distributed_data_loader +from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine @@ -61,12 +61,12 @@ def get_hf_token_bytes(tokenizer, device="cpu"): # CLI arguments parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model") -parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size") -parser.add_argument("--split_tokens", type=int, default=40*524288, help="number of tokens to evaluate per split") -parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory") -parser.add_argument("--model_step", type=int, default=None, help="model step to load") -parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") -parser.add_argument("--hf_path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)") +parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") +parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split") +parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory") +parser.add_argument("--model-step", type=int, default=None, help="model step to load") +parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. 
diff --git a/scripts/base_train.py b/scripts/base_train.py
index c7c5bba..bf4b8cf 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -8,7 +8,7 @@ or distributed as:
torchrun --nproc_per_node=8 -m scripts.base_train.py

If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
-python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
+python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
"""

import os
@@ -21,7 +21,7 @@ import wandb
import torch

from nanochat.gpt import GPT, GPTConfig
-from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
+from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
@@ -36,40 +36,40 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
-parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
-parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
-parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
-parser.add_argument("--window_pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
+parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
+parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
+parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
+parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Training horizon (only one used, in order of precedence)
-parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
-parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
-parser.add_argument("--target_param_data_ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
+parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
+parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
+parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
-parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
-parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--scalar_lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
-parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
-parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
-parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
-parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
-parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
-parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
+parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
+parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
+parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
+parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
+parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
+parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
+parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
+parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
+parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
-parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
-parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
-parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
-parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
-parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
-parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
+parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
+parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
+parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
+parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
+parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
+parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
-parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
+parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
@@ -210,8 +210,8 @@ if resuming:
# Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data")
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
-train_loader = tokenizing_distributed_data_loader_with_state(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
-build_val_loader = lambda: tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
+train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
+build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------
@@ -395,7 +395,8 @@ while True:
            eta_str = f" | eta: {eta_seconds/60:.1f}m"
        else:
            eta_str = ""
-        print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
+        epoch = dataloader_state_dict["epoch"]
+        print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
        if step % 100 == 0:
            log_data = {
                "step": step,
@@ -406,6 +407,7 @@ while True:
                "train/dt": dt,
                "train/tok_per_sec": tok_per_sec,
                "train/mfu": mfu,
+                "train/epoch": epoch,
            }
            wandb_run.log(log_data)
diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py
index ad557b9..b0697f3 100644
--- a/scripts/chat_rl.py
+++ b/scripts/chat_rl.py
@@ -35,32 +35,32 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
-parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
-parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
+parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
# Batch sizes / sampling
-parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
-parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
-parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
+parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
+parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
+parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
# Generation
-parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
+parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
-parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
+parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
# Optimization
-parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
-parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
+parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
+parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
# Evaluation / checkpointing
-parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
-parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
-parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
+parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
+parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
+parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 853a2bf..9277cf9 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -37,29 +37,29 @@ parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
-parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
-parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs")
-parser.add_argument("--num_iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
+parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
+parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
# Batch sizes
-parser.add_argument("--device_batch_size", type=int, default=4, help="per-device batch size")
-parser.add_argument("--target_examples_per_step", type=int, default=32, help="target examples per optimization step")
+parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
+parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
# Optimization
-parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
-parser.add_argument("--init_lr_frac", type=float, default=0.02, help="initial LR as fraction of base LR")
+parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
+parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
# Evaluation
-parser.add_argument("--eval_every", type=int, default=100, help="evaluate val loss every N steps")
-parser.add_argument("--eval_steps", type=int, default=100, help="number of batches for val loss evaluation")
-parser.add_argument("--eval_metrics_every", type=int, default=200, help="evaluate accuracy metrics every N steps")
-parser.add_argument("--eval_metrics_max_problems", type=int, default=1024, help="max problems per metric evaluation")
+parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
+parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
+parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
+parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index d684b9f..01d9f7d 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -6,11 +6,10 @@ python -m scripts.mid_train

Or torchrun for training:
-torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
+torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
"""
import argparse
-from collections import deque
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
@@ -37,28 +36,28 @@ parser = argparse.ArgumentParser(description="Midtrain the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
-parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
-parser.add_argument("--num_iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
+parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
-parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
-parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
+parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
+parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
+parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
-parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
-parser.add_argument("--init_lr_frac", type=float, default=1.0, help="initial LR as fraction of base LR")
+parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
+parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
-parser.add_argument("--eval_every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
-parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
+parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
+parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
-parser.add_argument("--dry_run", action="store_true", help="log to wandb but skip checkpoints/report")
+parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@@ -80,7 +79,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
-    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
+    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
@@ -125,49 +124,100 @@ val_dataset = TaskMixture([
# these two global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the training dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
-def mid_data_generator(split):
-    global last_step, approx_progress
+current_epoch = 1 # track epoch for logging
+def mid_data_generator_bos_bestfit(split, buffer_size=100):
+    """
+    BOS-aligned dataloader for midtraining with bestfit-crop packing.
+
+    Each row in the batch starts with BOS (the beginning of a conversation).
+    Conversations are packed using a best-fit algorithm to minimize cropping.
+    This matches the BOS-aligned approach used in pretraining.
+    """
+    global last_step, approx_progress, current_epoch
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
-    needed_tokens = args.device_batch_size * args.max_seq_len + 1 # to form one training batch of inputs,targets
-    token_buffer = deque()
-    # CUDA supports memory pinning for faster transfers between CPU and GPU:
-    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
-    cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
-    it = 0 # iteration counter
-    while True:
-        # Accumulate enough tokens for one iteration before yielding
-        while len(token_buffer) < needed_tokens:
+    row_capacity = args.max_seq_len + 1 # +1 for target at last position
+
+    # Conversation buffer: list of token lists
+    conv_buffer = []
+    cursor = ddp_rank # Each rank processes different conversations (for fetching)
+    consumed = ddp_rank # Track actual consumption separately from buffering
+    epoch = 1
+    it = 0 # iteration counter
+
+    def refill_buffer():
+        nonlocal cursor, epoch
+        while len(conv_buffer) < buffer_size:
            conversation = dataset[cursor]
            ids, _ = tokenizer.render_conversation(conversation)
-            token_buffer.extend(ids)
+            conv_buffer.append(ids)
            cursor += ddp_world_size
            if cursor >= dataset_size:
-                cursor -= dataset_size # wrap around for another epoch
-                if split == "train":
-                    last_step = True # toggle last_step to True, which will terminate the training loop
+                cursor = cursor % dataset_size
+                epoch += 1
+        # Note: last_step is now triggered based on consumption, not fetching
+
+    while True:
+        rows = []
+        for _ in range(args.device_batch_size):
+            row = []
+            while len(row) < row_capacity:
+                # Ensure buffer has conversations
+                while len(conv_buffer) < buffer_size:
+                    refill_buffer()
+
+                remaining = row_capacity - len(row)
+
+                # Find largest conversation that fits entirely
+                best_idx = -1
+                best_len = 0
+                for i, conv in enumerate(conv_buffer):
+                    conv_len = len(conv)
+                    if conv_len <= remaining and conv_len > best_len:
+                        best_idx = i
+                        best_len = conv_len
+
+                if best_idx >= 0:
+                    # Found a conversation that fits - use it entirely
+                    conv = conv_buffer.pop(best_idx)
+                    row.extend(conv)
+                    consumed += ddp_world_size # Track actual consumption
+                else:
+                    # No conversation fits - crop the first conversation to fill the remaining space
+                    conv = conv_buffer.pop(0)
+                    row.extend(conv[:remaining])
+                    consumed += ddp_world_size # Track actual consumption
+
+            rows.append(row[:row_capacity])
+
        # Stopping condition to respect num_iterations, if given
        it += 1
        if 0 < args.num_iterations <= it and split == "train":
-            last_step = True # toggle last_step to True, which will terminate the training loop
-        # Build up inputs/targets and yield
-        for i in range(needed_tokens):
-            scratch[i] = token_buffer.popleft()
-        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
-        targets_cpu = scratch[1:]
-        inputs = inputs_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
-        targets = targets_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
+            last_step = True
+
+        # Update progress tracking (based on consumed, not cursor, to account for buffering)
        if split == "train":
+            current_epoch = epoch
            if args.num_iterations > 0:
-                approx_progress = it / args.num_iterations # calculate progress from the max number of iterations
+                approx_progress = it / args.num_iterations
            else:
-                approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
+                approx_progress = consumed / dataset_size
+            # Trigger last_step when we've consumed enough (instead of when cursor wraps)
+            if consumed >= dataset_size:
+                last_step = True
+
+        # Build tensors
+        use_cuda = device_type == "cuda"
+        batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
+        inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
+        targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
+
        yield inputs, targets
-train_loader = mid_data_generator("train")
-build_val_loader = lambda: mid_data_generator("val")
+train_loader = mid_data_generator_bos_bestfit("train")
+build_val_loader = lambda: mid_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of the epoch

# Learning rate scheduler
@@ -285,7 +335,7 @@ while True:
        mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
        if step > 10:
            total_training_time += dt # only count the time after the first 10 steps
-        print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+        print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
        if step % 10 == 0:
            wandb_run.log({
                "step": step,
@@ -296,6 +346,7 @@ while True:
                "train/dt": dt,
                "train/tok_per_sec": tok_per_sec,
                "train/mfu": mfu,
+                "train/epoch": current_epoch,
            })

    # print a few more stats
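The heart of `mid_data_generator_bos_bestfit` above is the row-packing policy. A condensed, self-contained sketch of just that step, simplified to plain token lists with no DDP striding, buffering, or epoch bookkeeping:

```python
def pack_row(conv_buffer, row_capacity):
    """Fill one training row: greedily take the largest buffered conversation
    that still fits whole; when nothing fits, crop one to fill the remainder."""
    row = []
    while len(row) < row_capacity:
        remaining = row_capacity - len(row)
        # best-fit: largest conversation with len <= remaining
        best = max((c for c in conv_buffer if len(c) <= remaining), key=len, default=None)
        if best is not None:
            conv_buffer.remove(best)
            row.extend(best)
        else:
            # nothing fits whole: crop the oldest buffered conversation (rest is discarded)
            row.extend(conv_buffer.pop(0)[:remaining])
    return row

# toy example: capacity 10, conversations of length 6, 3, 5, 9
buf = [[1] * 6, [2] * 3, [3] * 5, [4] * 9]
print(len(pack_row(buf, 10)))  # 10 -- the 9-token conversation plus 1 cropped token
```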
diff --git a/scripts/tok_train.py b/scripts/tok_train.py
index 4ab995c..9c7979d 100644
--- a/scripts/tok_train.py
+++ b/scripts/tok_train.py
@@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched
# Parse command line arguments
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
-parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
-parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
-parser.add_argument('--vocab_size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
+parser.add_argument('--max-chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
+parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
+parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
args = parser.parse_args()
print(f"max_chars: {args.max_chars:,}")
print(f"doc_cap: {args.doc_cap:,}")
diff --git a/speedrun.sh b/speedrun.sh
index 9f445b3..468086b 100644
--- a/speedrun.sh
+++ b/speedrun.sh
@@ -56,11 +56,11 @@ python -m nanochat.report reset
# each shard is ~90MB of text (compressed), so this is about ~1GB of data on disk
python -m nanochat.dataset -n 11
# Immediately also kick off downloading more shards in the background while tokenizer trains
-# See comment below for why 240 is the right number here
-python -m nanochat.dataset -n 240 &
+# See comment below for why 370 is the right number here
+python -m nanochat.dataset -n 370 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
-python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536
+python -m scripts.tok_train --max-chars=2000000000 --vocab-size=65536
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval
@@ -71,7 +71,9 @@ python -m scripts.tok_eval
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
-# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
+# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping,
+# so 240 / (1 - 0.35) ~= 370 shards are needed.
+# At ~100MB/shard, this downloads ~37GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
@@ -80,7 +82,7 @@ wait $DATASET_DOWNLOAD_PID
NPROC_PER_NODE=8

# pretrain the d20 model
-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target_param_data_ratio=20 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks
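A quick check of the arithmetic in the updated speedrun.sh comment above: 240 / (1 - 0.35) ~= 369, which rounds up to the 370 shards now pre-downloaded, and 370 shards at ~100MB each is ~37GB on disk, consistent with the new numbers.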