From 075e3bb476a286b0d5c2ff5b31850205c0fb537f Mon Sep 17 00:00:00 2001 From: Amrit Bulusu Date: Mon, 16 Mar 2026 00:25:35 -0400 Subject: [PATCH] Fix muP coord check: remove lm_head double-compensation, use float32 --- nanochat/gpt.py | 7 ++----- scripts/mup_coord_check.py | 4 +++- scripts/mup_transfer_check.py | 4 +++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 6ab5eee..238deb9 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -219,11 +219,8 @@ class GPT(nn.Module): # Embedding and unembedding torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8) - lm_head_std = 0.001 - if self.config.mup_base_width > 0: - # muP: scale lm_head init by 1/sqrt(m_d) so raw logit magnitude is O(1) across widths. - # Without this, |logit| ~ 0.001 * sqrt(n_embd) grows as sqrt(width). - lm_head_std *= (self.config.mup_base_width / self.config.n_embd) ** 0.5 + # muP uses 0.02 for stronger initial logit signal; forward-pass scaling handles width independence + lm_head_std = 0.02 if self.config.mup_base_width > 0 else 0.001 torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std) # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal) diff --git a/scripts/mup_coord_check.py b/scripts/mup_coord_check.py index c251fdc..92b3941 100644 --- a/scripts/mup_coord_check.py +++ b/scripts/mup_coord_check.py @@ -17,6 +17,8 @@ Usage: """ import argparse +import os +os.environ["NANOCHAT_DTYPE"] = "float32" import torch import torch._dynamo torch._dynamo.config.disable = True @@ -299,7 +301,7 @@ def run_coord_check(config: CoordCheckConfig, device: torch.device, model.train() for step in range(config.num_steps): - with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')): + with torch.amp.autocast(device_type='cuda', dtype=torch.float32, enabled=False): loss = model(x, y) results['losses'][actual_width].append(loss.item()) diff --git a/scripts/mup_transfer_check.py b/scripts/mup_transfer_check.py index 34e911b..4ea9d58 100644 --- a/scripts/mup_transfer_check.py +++ b/scripts/mup_transfer_check.py @@ -31,6 +31,8 @@ Usage: """ import argparse +import os +os.environ["NANOCHAT_DTYPE"] = "float32" import torch import torch._dynamo torch._dynamo.config.disable = True @@ -173,7 +175,7 @@ def train_model(width: int, lr_mult: float, config: TransferCheckConfig, for step in range(config.num_steps): x, y = batches[step % num_batches] - with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')): + with torch.amp.autocast(device_type='cuda', dtype=torch.float32, enabled=False): loss = model(x, y) losses.append(loss.item())