mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 21:25:21 +00:00
Fix muP coord check: remove lm_head double-compensation, use float32
This commit is contained in:
parent
e3bef727b4
commit
075e3bb476
|
|
@ -219,11 +219,8 @@ class GPT(nn.Module):
|
|||
|
||||
# Embedding and unembedding
|
||||
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
|
||||
lm_head_std = 0.001
|
||||
if self.config.mup_base_width > 0:
|
||||
# muP: scale lm_head init std by 1/sqrt(m_d), where m_d = n_embd / mup_base_width, so raw logit magnitude is O(1) across widths.
|
||||
# Without this, |logit| ~ 0.001 * sqrt(n_embd) grows as sqrt(width).
|
||||
lm_head_std *= (self.config.mup_base_width / self.config.n_embd) ** 0.5
|
||||
# muP uses an lm_head init std of 0.02 (vs. 0.001 otherwise) for a stronger initial logit signal; forward-pass scaling handles width independence
|
||||
lm_head_std = 0.02 if self.config.mup_base_width > 0 else 0.001
|
||||
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std)
|
||||
|
||||
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ Usage:
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
os.environ["NANOCHAT_DTYPE"] = "float32"
|
||||
import torch
|
||||
import torch._dynamo
|
||||
torch._dynamo.config.disable = True
|
||||
|
|
@ -299,7 +301,7 @@ def run_coord_check(config: CoordCheckConfig, device: torch.device,
|
|||
model.train()
|
||||
|
||||
for step in range(config.num_steps):
|
||||
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')):
|
||||
with torch.amp.autocast(device_type='cuda', dtype=torch.float32, enabled=False):
|
||||
loss = model(x, y)
|
||||
|
||||
results['losses'][actual_width].append(loss.item())
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ Usage:
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
os.environ["NANOCHAT_DTYPE"] = "float32"
|
||||
import torch
|
||||
import torch._dynamo
|
||||
torch._dynamo.config.disable = True
|
||||
|
|
@ -173,7 +175,7 @@ def train_model(width: int, lr_mult: float, config: TransferCheckConfig,
|
|||
|
||||
for step in range(config.num_steps):
|
||||
x, y = batches[step % num_batches]
|
||||
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')):
|
||||
with torch.amp.autocast(device_type='cuda', dtype=torch.float32, enabled=False):
|
||||
loss = model(x, y)
|
||||
|
||||
losses.append(loss.item())
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user