Fix muP coord check: remove lm_head double-compensation, use float32

This commit is contained in:
Amrit Bulusu 2026-03-16 00:25:35 -04:00
parent e3bef727b4
commit 075e3bb476
3 changed files with 8 additions and 7 deletions

View File

@@ -219,11 +219,8 @@ class GPT(nn.Module):
# Embedding and unembedding
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
- lm_head_std = 0.001
- if self.config.mup_base_width > 0:
-     # muP: scale lm_head init by 1/sqrt(m_d) so raw logit magnitude is O(1) across widths.
-     # Without this, |logit| ~ 0.001 * sqrt(n_embd) grows as sqrt(width).
-     lm_head_std *= (self.config.mup_base_width / self.config.n_embd) ** 0.5
+ # muP uses 0.02 for stronger initial logit signal; forward-pass scaling handles width independence
+ lm_head_std = 0.02 if self.config.mup_base_width > 0 else 0.001
  torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std)
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)

View File

@@ -17,6 +17,8 @@ Usage:
"""
import argparse
import os
+ os.environ["NANOCHAT_DTYPE"] = "float32"
import torch
import torch._dynamo
torch._dynamo.config.disable = True
@@ -299,7 +301,7 @@ def run_coord_check(config: CoordCheckConfig, device: torch.device,
model.train()
for step in range(config.num_steps):
- with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')):
+ with torch.amp.autocast(device_type='cuda', dtype=torch.float32, enabled=False):
loss = model(x, y)
results['losses'][actual_width].append(loss.item())

View File

@@ -31,6 +31,8 @@ Usage:
"""
import argparse
import os
+ os.environ["NANOCHAT_DTYPE"] = "float32"
import torch
import torch._dynamo
torch._dynamo.config.disable = True
@@ -173,7 +175,7 @@ def train_model(width: int, lr_mult: float, config: TransferCheckConfig,
for step in range(config.num_steps):
x, y = batches[step % num_batches]
- with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')):
+ with torch.amp.autocast(device_type='cuda', dtype=torch.float32, enabled=False):
loss = model(x, y)
losses.append(loss.item())