also add base_train change example for how to swap LinearFP8

This commit is contained in:
Andrej Karpathy 2026-01-13 17:08:10 +00:00
parent a6382a6ce8
commit 69b1ed245e

View File

@ -22,6 +22,7 @@ import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
from nanochat.fp8_static import LinearFP8
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
import os
@ -159,7 +160,7 @@ class GPT(nn.Module):
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
self.lm_head = LinearFP8(config.n_embd, padded_vocab_size, bias=False, x_scale=100/448, w_scale=1.6/448, monitor=False)
# Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)