mirror of https://github.com/karpathy/nanochat.git
Refactor constants in training scripts and engine to improve configurability. Replace hardcoded values with constants for KV cache growth, rotary cache multiplier, and learning rate parameters. This enhances maintainability and allows for easier adjustments in future iterations.
This commit is contained in:
parent fae3aca951
commit d366f7e07d
nanochat/constants.py (new file, 28 lines added)
@@ -0,0 +1,28 @@
+"""
+Constants used throughout the nanochat codebase.
+Centralizes magic numbers for better maintainability.
+"""
+
+# Model Architecture Constants
+MODEL_ASPECT_RATIO = 64 # model_dim = depth * aspect_ratio
+HEAD_DIM_TARGET = 128 # Target dimension per attention head
+ROTARY_CACHE_MULTIPLIER = 10 # Precompute rotary embeddings for N times sequence length
+LOGIT_SOFTCAP = 15 # Soft capping value for logits using tanh (must be > 0)
+
+# Memory Management
+KV_CACHE_GROWTH_CHUNK = 1024 # Grow KV cache in chunks of this size (must be power of 2 for efficient bitwise rounding)
+
+# Training Constants
+DEFAULT_WARMUP_RATIO = 0.0 # Fraction of training for learning rate warmup
+DEFAULT_WARMDOWN_RATIO = 0.2 # Fraction of training for learning rate warmdown
+DEFAULT_FINAL_LR_FRAC = 0.0 # Final LR as fraction of initial LR
+MUON_MOMENTUM_RAMPUP_STEPS = 300 # Steps to ramp up Muon momentum
+MUON_MOMENTUM_START = 0.85 # Starting momentum for Muon optimizer
+MUON_MOMENTUM_END = 0.95 # Final momentum for Muon optimizer
+
+# Evaluation Constants
+LOSS_EMA_BETA = 0.9 # Exponential moving average decay for training loss
+WARMUP_IGNORE_STEPS = 10 # Ignore first N steps when calculating training time
+
+# Calculator Tool Constants
+CALCULATOR_TIMEOUT_SECONDS = 3 # Maximum time for calculator evaluation
@@ -19,6 +19,7 @@ from contextlib import contextmanager
 from collections import deque
 from nanochat.common import compute_init
 from nanochat.checkpoint_manager import load_model
+from nanochat.constants import KV_CACHE_GROWTH_CHUNK, CALCULATOR_TIMEOUT_SECONDS

 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@@ -32,7 +33,7 @@ def timeout(duration, formula):
     yield
     signal.alarm(0)

-def eval_with_timeout(formula, max_time=3):
+def eval_with_timeout(formula, max_time=CALCULATOR_TIMEOUT_SECONDS):
     try:
         with timeout(max_time, formula):
             with warnings.catch_warnings():
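For orientation, a standalone sketch of the SIGALRM timeout pattern this hunk touches. It is illustrative rather than the repository's exact code: the handler message and the return value of eval_with_timeout are assumptions, and signal.alarm is only available on Unix.

import signal
from contextlib import contextmanager

@contextmanager
def timeout(duration, formula):
    def handler(signum, frame):
        raise TimeoutError(f"evaluating {formula!r} took longer than {duration}s")
    signal.signal(signal.SIGALRM, handler)
    signal.alarm(duration)  # arm the alarm (whole seconds)
    yield
    signal.alarm(0)         # disarm once evaluation finished in time

def eval_with_timeout(formula, max_time=3):
    try:
        with timeout(max_time, formula):
            return eval(formula)  # the real helper also suppresses warnings here
    except Exception:
        return None

print(eval_with_timeout("2 ** 10"))            # 1024
print(eval_with_timeout("sum(range(10**9))"))  # None once the alarm fires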
@@ -107,8 +108,11 @@ class KVCache:
         t0, t1 = self.pos, self.pos + T_add
         # Dynamically grow the cache if needed
         if t1 > self.kv_cache.size(4):
-            t_needed = t1 + 1024 # as much as we need plus buffer of 1024
-            t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
+            chunk = KV_CACHE_GROWTH_CHUNK
+            assert chunk > 0 and (chunk & (chunk - 1)) == 0, \
+                "KV_CACHE_GROWTH_CHUNK must be a positive power of two"
+            t_needed = t1 + chunk # as much as we need plus buffer
+            t_needed = (t_needed + chunk - 1) & ~(chunk - 1) # then round up to the nearest multiple
             current_shape = list(self.kv_cache.shape)
             current_shape[4] = t_needed
             self.kv_cache.resize_(current_shape)
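The new assert encodes the constraint already noted in constants.py: the bitwise rounding only works when the chunk size is a power of two, because ~(chunk - 1) is then a mask that clears the low bits. A small illustrative check, not part of the diff:

chunk = 1024  # KV_CACHE_GROWTH_CHUNK
assert chunk > 0 and (chunk & (chunk - 1)) == 0  # power-of-two check, as in the hunk
for t1 in (1, 1023, 1024, 1025, 5000):
    t_needed = (t1 + chunk - 1) & ~(chunk - 1)   # round up to the next multiple of chunk
    print(t1, "->", t_needed)
# 1 -> 1024, 1023 -> 1024, 1024 -> 1024, 1025 -> 2048, 5000 -> 5120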
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.constants import ROTARY_CACHE_MULTIPLIER, LOGIT_SOFTCAP

 @dataclass
 class GPTConfig:
@@ -164,7 +165,7 @@ class GPT(nn.Module):
         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
         # so let's just over-compute them, but assert fail if we ever reach that amount.
         # In the future we can dynamically grow the cache, for now it's fine.
-        self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
+        self.rotary_seq_len = config.sequence_len * ROTARY_CACHE_MULTIPLIER # 10X over-compute should be enough, TODO make nicer?
         head_dim = config.n_embd // config.n_head
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
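For orientation, a generic sketch of the kind of cos/sin table that gets precomputed for ROTARY_CACHE_MULTIPLIER times the sequence length. The base frequency of 10000, the half-head-dim table width, and the example sequence length are assumptions for illustration, not details taken from nanochat's _precompute_rotary_embeddings:

import torch

def precompute_rotary(seq_len, head_dim, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)   # (seq_len, head_dim // 2)
    return freqs.cos(), freqs.sin()

# With ROTARY_CACHE_MULTIPLIER = 10 and a hypothetical sequence_len of 2048,
# the cache covers 20480 positions up front.
cos, sin = precompute_rotary(2048 * 10, head_dim=128)
print(cos.shape, sin.shape)  # torch.Size([20480, 64]) torch.Size([20480, 64])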
@@ -275,19 +276,18 @@ class GPT(nn.Module):
         x = norm(x)

         # Forward the lm_head (compute logits)
-        softcap = 15
         if targets is not None:
             # training mode: compute and return the loss
             # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
             logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
+            logits = LOGIT_SOFTCAP * torch.tanh(logits / LOGIT_SOFTCAP)
             logits = logits.float() # use tf32/fp32 for logits
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
             return loss
         else:
             # inference mode: compute and return the logits
             logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
+            logits = LOGIT_SOFTCAP * torch.tanh(logits / LOGIT_SOFTCAP)
             return logits

     @torch.inference_mode()
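Replacing the local softcap = 15 with LOGIT_SOFTCAP does not change behavior: the tanh softcap smoothly bounds logits to the open interval (-15, 15) while staying near-identity for small values. A quick numeric check, illustrative only:

import torch

LOGIT_SOFTCAP = 15
logits = torch.tensor([-100.0, -15.0, -1.0, 0.0, 1.0, 15.0, 100.0])
capped = LOGIT_SOFTCAP * torch.tanh(logits / LOGIT_SOFTCAP)
print(capped)  # approx [-15.00, -11.42, -1.00, 0.00, 1.00, 11.42, 15.00]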
@@ -22,6 +22,12 @@ from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
 from scripts.base_eval import evaluate_model
+from nanochat.constants import (
+    MODEL_ASPECT_RATIO, HEAD_DIM_TARGET, DEFAULT_WARMUP_RATIO,
+    DEFAULT_WARMDOWN_RATIO, DEFAULT_FINAL_LR_FRAC, LOSS_EMA_BETA,
+    WARMUP_IGNORE_STEPS, MUON_MOMENTUM_RAMPUP_STEPS,
+    MUON_MOMENTUM_START, MUON_MOMENTUM_END
+)
 print_banner()

 # -----------------------------------------------------------------------------
@@ -73,8 +79,8 @@ print0(f"Vocab size: {vocab_size:,}")

 # Model kwargs are derived from the desired depth of the model
 num_layers = depth
-model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
-num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
+model_dim = depth * MODEL_ASPECT_RATIO # aspect ratio (usually this is varied from 64 -> 128 as model size increases)
+num_heads = max(1, (model_dim + HEAD_DIM_TARGET - 1) // HEAD_DIM_TARGET) # ceiling division for target head dim
 num_kv_heads = num_heads # 1:1 MQA ratio
 print0(f"num_layers: {num_layers}")
 print0(f"model_dim: {model_dim}")
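As a worked example of the derivation above, with a hypothetical depth of 20 (not a value taken from the diff):

MODEL_ASPECT_RATIO = 64
HEAD_DIM_TARGET = 128

depth = 20
model_dim = depth * MODEL_ASPECT_RATIO                                     # 1280
num_heads = max(1, (model_dim + HEAD_DIM_TARGET - 1) // HEAD_DIM_TARGET)  # ceil(1280 / 128) = 10
head_dim = model_dim // num_heads                                          # 128, matching the target
print(model_dim, num_heads, head_dim)  # 1280 10 128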
@@ -142,9 +148,9 @@ x, y = next(train_loader) # kick off load of the very first batch of data

 # Learning rate scheduler
 # TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
-warmup_ratio = 0.0 # ratio of iterations for LR warmup
-warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
-final_lr_frac = 0.0 # final LR is this fraction of the initial LR
+warmup_ratio = DEFAULT_WARMUP_RATIO # ratio of iterations for LR warmup
+warmdown_ratio = DEFAULT_WARMDOWN_RATIO # ratio of iterations for LR warmdown
+final_lr_frac = DEFAULT_FINAL_LR_FRAC # final LR is this fraction of the initial LR
 def get_lr_multiplier(it):
     warmup_iters = round(warmup_ratio * num_iterations)
     warmdown_iters = round(warmdown_ratio * num_iterations)
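The body of get_lr_multiplier is not part of this hunk, so the shape below is an assumption: linear warmup, a constant plateau, and a linear warmdown toward final_lr_frac, which is what the three renamed ratios suggest. A minimal sketch under those assumptions:

DEFAULT_WARMUP_RATIO = 0.0
DEFAULT_WARMDOWN_RATIO = 0.2
DEFAULT_FINAL_LR_FRAC = 0.0

def lr_multiplier(it, num_iterations):
    warmup_iters = round(DEFAULT_WARMUP_RATIO * num_iterations)
    warmdown_iters = round(DEFAULT_WARMDOWN_RATIO * num_iterations)
    if it < warmup_iters:                              # linear warmup (skipped when the ratio is 0.0)
        return (it + 1) / warmup_iters
    if it <= num_iterations - warmdown_iters:          # constant plateau
        return 1.0
    progress = (num_iterations - it) / warmdown_iters  # linear warmdown to the final fraction
    return progress + (1 - progress) * DEFAULT_FINAL_LR_FRAC

print([round(lr_multiplier(i, 1000), 2) for i in (0, 500, 800, 900, 1000)])
# [1.0, 1.0, 1.0, 0.5, 0.0]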
@@ -158,15 +164,15 @@ def get_lr_multiplier(it):

 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
-    frac = min(it / 300, 1)
-    momentum = (1 - frac) * 0.85 + frac * 0.95
+    frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
+    momentum = (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END
     return momentum

 # -----------------------------------------------------------------------------
 # Training loop
 min_val_bpb = float("inf")
 smooth_train_loss = 0 # EMA of training loss
-ema_beta = 0.9 # EMA decay factor
+ema_beta = LOSS_EMA_BETA # EMA decay factor
 total_training_time = 0 # total wall-clock time of training
 # note that we run +1 steps only so that we can eval and save at the end
 for step in range(num_iterations + 1):
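Evaluating the renamed momentum ramp at a few steps makes the linear interpolation concrete (simple arithmetic on the constants, illustrative only):

MUON_MOMENTUM_RAMPUP_STEPS = 300
MUON_MOMENTUM_START, MUON_MOMENTUM_END = 0.85, 0.95

def get_muon_momentum(it):
    frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
    return (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END

print([round(get_muon_momentum(it), 3) for it in (0, 150, 300, 1000)])
# [0.85, 0.9, 0.95, 0.95]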
@@ -288,7 +294,7 @@ for step in range(num_iterations + 1):
     flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
-    if step > 10:
+    if step > WARMUP_IGNORE_STEPS:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 100 == 0:
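To make the MFU line concrete, the same arithmetic with made-up numbers; every value below is hypothetical and not from an actual run:

num_flops_per_token = 1.0e9   # hypothetical FLOPs per token
total_batch_size = 524_288    # hypothetical tokens per step
dt = 0.5                      # hypothetical seconds per step
ddp_world_size = 8            # hypothetical number of GPUs

flops_per_sec = num_flops_per_token * total_batch_size / dt  # ~1.05e15
promised_flops_per_sec_h100 = 989e12 * ddp_world_size        # 7.912e15 (bf16 H100 SXM peak)
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100
print(f"{mfu:.1f}%")  # ~13.3%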
@@ -27,6 +27,10 @@ from tasks.common import TaskMixture
 from tasks.gsm8k import GSM8K
 from tasks.mmlu import MMLU
 from tasks.smoltalk import SmolTalk
+from nanochat.constants import (
+    LOSS_EMA_BETA, WARMUP_IGNORE_STEPS,
+    MUON_MOMENTUM_RAMPUP_STEPS, MUON_MOMENTUM_START, MUON_MOMENTUM_END
+)

 # -----------------------------------------------------------------------------
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@@ -146,8 +150,8 @@ def get_lr_multiplier(progress):

 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
-    frac = min(it / 300, 1)
-    momentum = (1 - frac) * 0.85 + frac * 0.95
+    frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
+    momentum = (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END
     return momentum

 # -----------------------------------------------------------------------------
@@ -155,7 +159,7 @@ def get_muon_momentum(it):
 x, y = next(train_loader) # prefetch the very first batch of data
 min_val_bpb = float("inf")
 smooth_train_loss = 0 # EMA of training loss
-ema_beta = 0.9 # EMA decay factor
+ema_beta = LOSS_EMA_BETA # EMA decay factor
 total_training_time = 0 # total wall-clock time of training
 step = 0
 while True:
@@ -252,7 +256,7 @@ while True:
     flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
-    if step > 10:
+    if step > WARMUP_IGNORE_STEPS:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 10 == 0: