Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-20 18:34:14 +00:00)
also add base_train change example for how to swap LinearFP8
This commit is contained in:
parent a6382a6ce8
commit 69b1ed245e
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.fp8_static import LinearFP8

 # Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
 import os
@@ -159,7 +160,7 @@ class GPT(nn.Module):
             "wte": nn.Embedding(padded_vocab_size, config.n_embd),
             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
         })
-        self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
+        self.lm_head = LinearFP8(config.n_embd, padded_vocab_size, bias=False, x_scale=100/448, w_scale=1.6/448, monitor=False)
         # Per-layer learnable scalars (inspired by modded-nanogpt)
         # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
         # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
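For readers without the rest of the repo at hand, below is a rough sketch of how a statically scaled FP8 linear layer of this kind can be emulated in plain PyTorch. It is an illustration of the idea only, not the actual nanochat.fp8_static.LinearFP8 implementation (which presumably dispatches to a real FP8 matmul kernel and implements the monitor flag); the scale semantics are assumed from the values in the diff, where 448 matches the float8_e4m3 maximum.

# Illustrative sketch only -- NOT the actual nanochat.fp8_static.LinearFP8.
# Assumed semantics: x_scale = 100/448 and w_scale = 1.6/448 map expected
# activation/weight magnitudes onto the float8_e4m3 range (max ~448).
import torch
import torch.nn as nn
import torch.nn.functional as F

FP8_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn

class LinearFP8Sketch(nn.Module):
    """Emulates a statically scaled FP8 matmul by round-tripping activations
    and weights through float8_e4m3fn before a regular linear call."""
    def __init__(self, in_features, out_features, bias=False, x_scale=1.0, w_scale=1.0):
        super().__init__()
        assert not bias, "sketch only covers the bias-free case used for the lm_head"
        self.weight = nn.Parameter(torch.empty(out_features, in_features).normal_(std=0.02))
        self.x_scale = x_scale  # e.g. 100/448: activations assumed to peak near 100
        self.w_scale = w_scale  # e.g. 1.6/448: weights assumed to peak near 1.6

    @staticmethod
    def _quant_dequant(t, scale):
        # Divide by the static scale so the expected max lands near 448,
        # cast to FP8 (this is where precision is lost), then undo the scale.
        q = (t / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        return q.to(t.dtype) * scale

    def forward(self, x):
        x_q = self._quant_dequant(x, self.x_scale)
        w_q = self._quant_dequant(self.weight, self.w_scale)
        return F.linear(x_q, w_q)

Usage would mirror the swapped line in the diff above: LinearFP8Sketch(config.n_embd, padded_vocab_size, bias=False, x_scale=100/448, w_scale=1.6/448).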