From f8ff0439b9b9192399deb1ed8a09874152b4a407 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Fri, 6 Mar 2026 11:03:00 +0100 Subject: [PATCH 01/13] two more small typos --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 077fd9c..6be1109 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train This uses wandb (run name "d12"), only runs the CORE metric on last step, and it doesn't sample and save intermediate checkpoints. I like to change something in the code, re-run a d12 (or a d16 etc) and see if it helped, in an iteration loop. To see if a run helps, I like to monitor the wandb plots for: 1. `val_bpb` (validation loss in vocab-size-invariant units of bits per byte) as a function of `step`, `total_training_time` and `total_training_flops`. -2. `core_metric` (the DCLM CORE socre) +2. `core_metric` (the DCLM CORE score) 3. VRAM utilization, `train/mfu` (Model FLOPS utilization), `train/tok_per_sec` (training throughput) See an example [here](https://github.com/karpathy/nanochat/pull/498#issuecomment-3850720044). @@ -101,7 +101,7 @@ NANOCHAT_DTYPE=bfloat16 torchrun --nproc_per_node=8 -m scripts.base_train # for How it works: model weights are stored in fp32 (for optimizer precision), but our custom `Linear` layer casts them to `COMPUTE_DTYPE` during the forward pass. Embeddings are stored directly in `COMPUTE_DTYPE` to save memory. This gives us the same mixed-precision benefit as autocast but with full explicit control over what runs in which precision. -Note: `float16` training automatically enables a `GradScaler` in `base_train.py` to prevent gradient underflow. SFT suppors this too but RL currently does not. Inference in fp16 works fine everywhere. +Note: `float16` training automatically enables a `GradScaler` in `base_train.py` to prevent gradient underflow. SFT supports this too but RL currently does not. Inference in fp16 works fine everywhere. ## Guides From d96558bcb0dc11b546bebff79bc0f56fa944c362 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Tue, 10 Mar 2026 09:57:30 +0100 Subject: [PATCH 02/13] fix heading, cf #622 --- .claude/skills/read-arxiv-paper/SKILL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.claude/skills/read-arxiv-paper/SKILL.md b/.claude/skills/read-arxiv-paper/SKILL.md index 6a9cda7..0a1b131 100644 --- a/.claude/skills/read-arxiv-paper/SKILL.md +++ b/.claude/skills/read-arxiv-paper/SKILL.md @@ -33,7 +33,7 @@ Every latex source usually has an entrypoint, such as `main.tex` or something li Once you've found the entrypoint, Read the contents and then recurse through all other relevant source files to read the paper. -#### Part 6: Report +### Part 6: Report Once you've read the paper, produce a summary of the paper into a markdown file at `./knowledge/summary_{tag}.md`. Notice that 1) use the local knowledge directory here (it's easier for me to open and reference here), not in `~/.cache`, and 2) generate some reasonable `tag` like e.g. `conditional_memory` or whatever seems appropriate given the paper. Probably make sure that the tag doesn't exist yet so you're not overwriting files. From 2bb93b2ae4c8a4afc6a3d5741c934f0e0976b4c2 Mon Sep 17 00:00:00 2001 From: 2bitbit <180839704+2bitbit@users.noreply.github.com> Date: Thu, 12 Mar 2026 17:03:26 +0800 Subject: [PATCH 03/13] fix: correct minor typos in help text, README, and comments --- README.md | 2 +- scripts/chat_sft.py | 2 +- scripts/tok_train.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 1fed675..ea4132e 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ A few more notes: - The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower. - All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer. -- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative. +- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device-batch-size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative. - Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't personally exercised all of these code paths so there might be sharp edges. ## Research diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..a1cca8b 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -177,7 +177,7 @@ val_dataset = TaskMixture([ SmolTalk(split="test"), # 24K rows in test set MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios -]) # total: 24K + 14K + 1.32K ~= 39K rows +]) # total: 24K + 5.2K + 0.42K ~= 29.6K rows # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) # A big problem is that we don't know the final num_iterations in advance. So we create # these two global variables and update them from within the data generator. diff --git a/scripts/tok_train.py b/scripts/tok_train.py index 480e0e1..90495b1 100644 --- a/scripts/tok_train.py +++ b/scripts/tok_train.py @@ -14,7 +14,7 @@ from nanochat.dataset import parquets_iter_batched # Parse command line arguments parser = argparse.ArgumentParser(description='Train a BPE tokenizer') -parser.add_argument('--max-chars', type=int, default=2_000_000_000, help='Maximum characters to train on (default: 10B)') +parser.add_argument('--max-chars', type=int, default=2_000_000_000, help='Maximum characters to train on (default: 2B)') parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)') args = parser.parse_args() From a641b6ca966fdabe81d8c30f25b287f3de9039a3 Mon Sep 17 00:00:00 2001 From: Mathieu Lacage Date: Fri, 13 Mar 2026 13:19:10 +0100 Subject: [PATCH 04/13] MMLU main split is named auxiliary_train, not train --- scripts/chat_sft.py | 2 +- tasks/common.py | 4 ++-- tasks/mmlu.py | 9 ++------- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..ab886a7 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -166,7 +166,7 @@ train_tasks = [ SmolTalk(split="train"), # 460K rows of general conversations CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these - *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch + *[MMLU(subset="all", split="auxiliary_train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) diff --git a/tasks/common.py b/tasks/common.py index 2d6ddd8..211ff3f 100644 --- a/tasks/common.py +++ b/tasks/common.py @@ -135,12 +135,12 @@ if __name__ == "__main__": # very lightweight test of slicing from tasks.mmlu import MMLU - ds = MMLU(subset="auxiliary_train", split="train") + ds = MMLU(subset="all", split="auxiliary_train") print("Length of MMLU: ", len(ds)) ex = ds[5] print("5th example: ", ex) - ds = MMLU(subset="auxiliary_train", split="train", start=5, stop=10) + ds = MMLU(subset="all", split="auxiliary_train", start=5, stop=10) print("Length of sliced MMLU[5:10]: ", len(ds)) print("0th example of sliced MMLU: ", ds[0]) diff --git a/tasks/mmlu.py b/tasks/mmlu.py index 3ba2254..4721f9f 100644 --- a/tasks/mmlu.py +++ b/tasks/mmlu.py @@ -13,16 +13,11 @@ class MMLU(Task): def __init__(self, subset, split, **kwargs): super().__init__(**kwargs) - assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train" - assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test" - if subset == "auxiliary_train": - assert split == "train", "auxiliary_train must be split into train" + assert subset in ["all"], f"subset {subset} must be all" + assert split in ["auxiliary_train", "validation", "dev", "test"], f"split {split} must be auxiliary_train|validation|dev|test" self.subset = subset self.split = split self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42) - if subset == "auxiliary_train": - # I don't understand why but the auxiliary_train rows have some weird additional 'train' wrapper - self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train']) @property def eval_type(self): From 1052d25d454847a4bbf2cb85cbee250471535814 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Fri, 13 Mar 2026 13:46:16 +0100 Subject: [PATCH 05/13] we only need to wait 2h now! --- dev/LEADERBOARD.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/LEADERBOARD.md b/dev/LEADERBOARD.md index 556ec3c..6fdeaa3 100644 --- a/dev/LEADERBOARD.md +++ b/dev/LEADERBOARD.md @@ -36,7 +36,7 @@ Note that: - `target-param-data-ratio=8.25` controls the training horizon, which is determined in the script by taking the number of non-embedding model parameters and simply multiplying by this number. The current optimal Tokens:Params ratio can be seen in the defaults of the `base_train.py` script (it is 10.5). 10.5 would produce the *compute optimal* model given the currently measured scaling laws. However, GPT-2 capability is currently somewhere in between a d24 and d26. So to reach it exactly, we want to either overtrain d24 or undertrain d26. In this particular example, I am choosing to slightly undertrain a d26. Note that odd depths (e.g. d25) are not super recommended to use because the math around the transformer sizing and its head dimensions doesn't come out neatly. - `--fp8` turns on fp8 training. If your GPU does not support fp8, you can leave this out and the code will simply train in bf16. bf16 is higher precision than fp8, so you can actually expect that you might be able to do fewer steps (lower the `target-param-data-ratio`) to achieve the same capability. -Once you kick off the run, you wait ~3 hours and then at the end you'll see something like: +Once you kick off the run, you wait ~2 hours and then at the end you'll see something like: ``` wandb: Run summary: From bd6e9c8d5fb1d02f43bb4bb0c837736183662b39 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Sun, 15 Mar 2026 22:18:18 +0100 Subject: [PATCH 06/13] fix numbering --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9c09cc3..fa0cd23 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Presently, the main focus of development is on tuning the pretraining stage, whi | 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy | | 4 | 2.02 | 0.71854 | 0.2571 | change dataset to NVIDIA ClimbMix | Mar 4 2026 | 324e69c | @ddudek @karpathy | | 5 | 1.80 | 0.71808 | 0.2690 | autoresearch [round 1](https://x.com/karpathy/status/2031135152349524125) | Mar 9 2026 | 6ed7d1d | @karpathy | -| 5 | 1.65 | 0.71800 | 0.2626 | autoresearch round 2 | Mar 14 2026 | a825e63 | @karpathy | +| 6 | 1.65 | 0.71800 | 0.2626 | autoresearch round 2 | Mar 14 2026 | a825e63 | @karpathy | The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 2 hours is ~$48). From 1f9e42a85588c34be86e4cb30db5488b0f01f4c2 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Sun, 15 Mar 2026 22:27:18 +0100 Subject: [PATCH 07/13] two more typos, from PR 645 --- .claude/skills/read-arxiv-paper/SKILL.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.claude/skills/read-arxiv-paper/SKILL.md b/.claude/skills/read-arxiv-paper/SKILL.md index 0a1b131..cebee1b 100644 --- a/.claude/skills/read-arxiv-paper/SKILL.md +++ b/.claude/skills/read-arxiv-paper/SKILL.md @@ -1,6 +1,6 @@ --- name: read-arxiv-paper -description: Use this skill when when asked to read an arxiv paper given an arxiv URL +description: Use this skill when asked to read an arxiv paper given an arxiv URL --- You will be given a URL of an arxiv paper, for example: @@ -37,4 +37,4 @@ Once you've found the entrypoint, Read the contents and then recurse through all Once you've read the paper, produce a summary of the paper into a markdown file at `./knowledge/summary_{tag}.md`. Notice that 1) use the local knowledge directory here (it's easier for me to open and reference here), not in `~/.cache`, and 2) generate some reasonable `tag` like e.g. `conditional_memory` or whatever seems appropriate given the paper. Probably make sure that the tag doesn't exist yet so you're not overwriting files. -As for the summary itself, remember that you're processing this paper within the context of the nanochat repository, so most often we we will be interested in how to apply the paper and its lessons to the nanochat project. Therefore, you should feel free to "remind yourself" of the related nanochat code by reading the relevant parts, and then explicitly make the connection of how this paper might relate to nanochat or what are things we might be inspired about or try. +As for the summary itself, remember that you're processing this paper within the context of the nanochat repository, so most often we will be interested in how to apply the paper and its lessons to the nanochat project. Therefore, you should feel free to "remind yourself" of the related nanochat code by reading the relevant parts, and then explicitly make the connection of how this paper might relate to nanochat or what are things we might be inspired about or try. From 51f42a4406ccd5223f945edbbd6deefba14e3f97 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Sun, 15 Mar 2026 22:29:27 +0100 Subject: [PATCH 08/13] ~1.5h :-) --- dev/LEADERBOARD.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/LEADERBOARD.md b/dev/LEADERBOARD.md index 097c394..c3fa8cd 100644 --- a/dev/LEADERBOARD.md +++ b/dev/LEADERBOARD.md @@ -36,7 +36,7 @@ Note that: - `target-param-data-ratio=8.25` controls the training horizon, which is determined in the script by taking the number of non-embedding model parameters and simply multiplying by this number. The current optimal Tokens:Params ratio can be seen in the defaults of the `base_train.py` script (it is 10.5). 10.5 would produce the *compute optimal* model given the currently measured scaling laws. However, GPT-2 capability is currently somewhere in between a d24 and d26. So to reach it exactly, we want to either overtrain d24 or undertrain d26. In this particular example, I am choosing to slightly undertrain a d26. Note that odd depths (e.g. d25) are not super recommended to use because the math around the transformer sizing and its head dimensions doesn't come out neatly. - `--fp8` turns on fp8 training. If your GPU does not support fp8, you can leave this out and the code will simply train in bf16. bf16 is higher precision than fp8, so you can actually expect that you might be able to do fewer steps (lower the `target-param-data-ratio`) to achieve the same capability. -Once you kick off the run, you wait ~2 hours and then at the end you'll see something like: +Once you kick off the run, you wait ~1.5 hours and then at the end you'll see something like: ``` wandb: Run summary: From c16db281ffe816966e8a4e1ef79b00d4b627228a Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 24 Mar 2026 19:25:34 +0000 Subject: [PATCH 09/13] fix small bug with params logging and batch size --- runs/scaling_laws.sh | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/runs/scaling_laws.sh b/runs/scaling_laws.sh index 212e675..0e0b600 100644 --- a/runs/scaling_laws.sh +++ b/runs/scaling_laws.sh @@ -8,7 +8,7 @@ FLOPS_BUDGETS=( 4.64e18 1e19 ) -DEPTHS=(8 10 12 14 16 18 20) +DEPTHS=(10 12 14 16 18 20) NPROC_PER_NODE="${NPROC_PER_NODE:-8}" WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}" @@ -60,6 +60,15 @@ for flops in "${FLOPS_BUDGETS[@]}"; do # Unique tag for this run TAG="scaling_${flops}_d${d}" + # Reduce --device-batch-size to avoid OOM at larger depths + if [ $d -ge 28 ]; then + DEVICE_BATCH_SIZE_ARG="--device-batch-size=8" + elif [ $d -ge 20 ]; then + DEVICE_BATCH_SIZE_ARG="--device-batch-size=16" + else + DEVICE_BATCH_SIZE_ARG="--device-batch-size=32" + fi + # Record start time START_TIME=$(date +%s) @@ -77,6 +86,7 @@ for flops in "${FLOPS_BUDGETS[@]}"; do --core-metric-max-per-task=-1 \ --sample-every=-1 \ --save-every=-1 \ + $DEVICE_BATCH_SIZE_ARG \ 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" END_TIME=$(date +%s) @@ -96,8 +106,9 @@ for flops in "${FLOPS_BUDGETS[@]}"; do PARAMS_TOTAL=$(grep "^total " "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',') NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',') - # Calculate tokens trained (iterations * batch_size, default 524288) - TOKENS_TRAINED=$((NUM_ITERS * 524288)) + # Extract actual batch size from log (auto-computed, varies by model size) + BATCH_SIZE=$(grep "Total batch size" "$LOG_FILE" | tail -1 | grep -oP 'Total batch size \K[\d,]+' | tr -d ',') + TOKENS_TRAINED=$((NUM_ITERS * BATCH_SIZE)) # Model dim MODEL_DIM=$((d * 64)) # Val BPB from final eval From 1cd94d768f14ac4a20249eedc89df568f3f4d50b Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 24 Mar 2026 19:25:50 +0000 Subject: [PATCH 10/13] bump D:N ratio to 12 per recent scaling laws re-run --- scripts/base_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index 86aa770..c7683c9 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -55,7 +55,7 @@ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") -parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") +parser.add_argument("--target-param-data-ratio", type=float, default=12, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.") parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") From 4e1694cc957075591fda8adb4a1f34b2f47fdea1 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 24 Mar 2026 22:13:13 +0000 Subject: [PATCH 11/13] bunch of ideas tried from openai/parameter-golf, all negative results for nanochat --- dev/LOG.md | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index fd5c3c7..dddfcb0 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,59 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-03-24: Parameter-Golf Ideas Sweep (Negative) + +Reviewed `openai/parameter-golf` for small/simple ideas that might transfer to nanochat pretraining without bloating the codebase. Cached notes are in `knowledge/parameter_golf.md`. + +### Rationale + +The parameter-golf leaderboard is a useful source of: + +- tiny architecture tweaks +- short-run optimizer/schedule tricks +- Muon-related systems ideas + +But much of that repo is optimized for a very different objective: + +- fit in a 16MB artifact +- train in under 10 minutes on 8xH100 +- evaluate on compression / bpb + +So only a small subset of ideas looked worth trying in nanochat. + +### Ideas Tried + +**1. LeakyReLU(0.5)^2** +- Replaced `relu^2` in the MLP with `leaky_relu(x, 0.5)^2` +- **Result:** Slightly better per-step quality, but slightly slower. Net worse on wall clock. + +**2. Partial RoPE** +- Applied rotary embeddings to only the first quarter of each head dimension +- **Result:** Slightly worse. + +**3. LN Scale** +- Multiplied each block's normalized input by `1/sqrt(layer_idx+1)` before attention and MLP +- **Result:** Did not help. + +**4. Orthogonal init** +- Switched the non-zero transformer matrices to orthogonal init while preserving zero-init output projections +- **Result:** Did not help. + +**5. XSA (Exclusive Self Attention)** +- Implemented XSA on the deepest 3 non-VE layers only, so it projected against the plain `v` path rather than `v + VE` +- **Result:** Slightly better step quality but not wall clock. Not worth the extra compute in the hot attention path. + +### Notes + +- EMA/SWA had already been tried earlier (I skipped recording it) and did not help. +- Bigram hash embeddings had already been explored much earlier and did help somewhat, but the added parameters / VRAM / complexity were not justified at larger scale. See the Jan 27-28 entries above. + +### Conclusion + +This pass did not find any cheap parameter-golf transfer that clearly improves nanochat on the metric that matters: wall clock time to capability. + +--- + ## 2026-03-04: Remove autocast, explicit dtype management, fp16 GradScaler Replaced `torch.amp.autocast` throughout the codebase with explicit dtype management via a single `COMPUTE_DTYPE` global. Also added fp16 training support with GradScaler. From c0dbf1f3fff10ef9d1a50e14a6188e04506251b6 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 25 Mar 2026 20:19:14 +0000 Subject: [PATCH 12/13] use COMPUTE_DTYPE-aware cast in Muon polar express step The bf16 cast is intentional for speed on Hopper+ GPUs, but should be skipped on other platforms rather than blindly applied. fp16 is unstable here due to its limited exponent range, and fp32 platforms don't benefit from the cast. Now: bf16 when COMPUTE_DTYPE is bf16, no cast otherwise. Inspired by PR #667. Co-Authored-By: Claude Opus 4.6 (1M context) --- nanochat/optim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nanochat/optim.py b/nanochat/optim.py index 0ee2e27..56e85e1 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -10,6 +10,7 @@ Further contributions from @karpathy and @chrisjmccormick. import torch import torch.distributed as dist from torch import Tensor +from nanochat.common import COMPUTE_DTYPE # ----------------------------------------------------------------------------- """ @@ -112,7 +113,8 @@ def muon_step_fused( g = stacked_grads.lerp_(momentum_buffer, momentum) # Polar express - X = g.bfloat16() + # Cast to bf16 for speed when available; skip cast otherwise (fp16 is unstable here due to limited exponent range) + X = g.bfloat16() if COMPUTE_DTYPE == torch.bfloat16 else g X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6) if g.size(-2) > g.size(-1): # Tall matrix for a, b, c in polar_express_coeffs[:ns_steps]: From 47e983eea7513d545fb6becc8b32756b6c43d06b Mon Sep 17 00:00:00 2001 From: RoomWithOutRoof <166608075+Jah-yee@users.noreply.github.com> Date: Thu, 26 Mar 2026 05:24:57 +0800 Subject: [PATCH 13/13] fix: use meta device in disable_fp8 to avoid VRAM spike (#616) When swapping Float8Linear to Linear in disable_fp8 context manager, using device=fp8_module.weight.device directly allocates new tensors on GPU, causing unnecessary VRAM spike (~1GB for large models). This fix uses device='meta' to avoid physical memory allocation, then swaps in the weight tensor reference. This eliminates the unnecessary VRAM spike during evaluation phase. Fixes issue #592 Co-authored-by: RoomWithOutRoof --- scripts/base_train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index c7683c9..a161c47 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -218,12 +218,13 @@ def disable_fp8(model): return # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype) + # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards for parent, attr_name, fp8_module in fp8_locations: linear = Linear( fp8_module.in_features, fp8_module.out_features, bias=fp8_module.bias is not None, - device=fp8_module.weight.device, + device="meta", # Use meta device to avoid unnecessary VRAM allocation dtype=fp8_module.weight.dtype, ) linear.weight = fp8_module.weight # share, don't copy