diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 81ccb0c..88c6687 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.fp8_static import LinearFP8
 
 # Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
 import os
@@ -159,7 +160,7 @@ class GPT(nn.Module):
             "wte": nn.Embedding(padded_vocab_size, config.n_embd),
             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
         })
-        self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
+        self.lm_head = LinearFP8(config.n_embd, padded_vocab_size, bias=False, x_scale=100/448, w_scale=1.6/448, monitor=False)
         # Per-layer learnable scalars (inspired by modded-nanogpt)
         # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
         # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
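
For context: `nanochat/fp8_static.py` itself is not part of this diff, so the sketch below is an illustrative guess at what a statically scaled FP8 linear layer could look like, not the actual module. Only the class name and the constructor signature (`x_scale`, `w_scale`, `monitor`) come from the diff; the `quantize_fp8_static` helper, the fake-quantization numerics, and the decision to ignore `monitor` are assumptions. A production version would likely run the matmul natively in FP8 on supporting hardware (e.g. via `torch._scaled_mm` on Hopper) and handle gradients explicitly; this sketch only demonstrates the static-scaling numerics in the forward pass.

```python
# Hypothetical sketch of nanochat/fp8_static.py -- NOT the actual implementation.
# Static FP8 scaling: inputs and weights are divided by fixed, pre-chosen scales
# so their expected dynamic range maps onto the float8_e4m3fn representable
# range (max normal value ~448), round-tripped through FP8, then rescaled.
import torch
import torch.nn as nn
import torch.nn.functional as F

FP8_MAX = 448.0  # max normal value of torch.float8_e4m3fn


def quantize_fp8_static(t: torch.Tensor, scale: float) -> torch.Tensor:
    # Divide by the scale so the expected max lands at FP8_MAX, clamp to avoid
    # overflow, cast to FP8 and back ("fake quantization"), then undo the scale.
    t_scaled = (t / scale).clamp(-FP8_MAX, FP8_MAX)
    t_fp8 = t_scaled.to(torch.float8_e4m3fn)
    return t_fp8.to(t.dtype) * scale


class LinearFP8(nn.Linear):
    def __init__(self, in_features, out_features, bias=False,
                 x_scale=1.0, w_scale=1.0, monitor=False):
        super().__init__(in_features, out_features, bias=bias)
        # Static (non-learned, non-dynamic) scales fixed at construction time.
        self.x_scale = x_scale
        self.w_scale = w_scale
        self.monitor = monitor  # presumably logs clipping/overflow stats; omitted here

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_q = quantize_fp8_static(x, self.x_scale)
        w_q = quantize_fp8_static(self.weight, self.w_scale)
        return F.linear(x_q, w_q, self.bias)
```

Reading the scales in the diff this way: `x_scale=100/448` and `w_scale=1.6/448` would mean the lm_head activations and weights are expected to stay within roughly ±100 and ±1.6 respectively, so dividing by those scales stretches each range to fill FP8's ±448 before quantization.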