Refactor constants in training scripts and engine to improve configurability. Replace hardcoded values with constants for KV cache growth, rotary cache multiplier, and learning rate parameters. This enhances maintainability and allows for easier adjustments in future iterations.

This commit is contained in:
dssjon 2025-10-15 14:58:06 -07:00
parent fae3aca951
commit d366f7e07d
5 changed files with 62 additions and 20 deletions

28
nanochat/constants.py Normal file
View File

@ -0,0 +1,28 @@
"""
Constants used throughout the nanochat codebase.
Centralizes magic numbers for better maintainability.
"""
# Model Architecture Constants
MODEL_ASPECT_RATIO = 64 # model_dim = depth * aspect_ratio
HEAD_DIM_TARGET = 128 # Target dimension per attention head
ROTARY_CACHE_MULTIPLIER = 10 # Precompute rotary embeddings for N times sequence length
LOGIT_SOFTCAP = 15 # Soft capping value for logits using tanh (must be > 0)
# Memory Management
KV_CACHE_GROWTH_CHUNK = 1024 # Grow KV cache in chunks of this size (must be power of 2 for efficient bitwise rounding)
# Training Constants
DEFAULT_WARMUP_RATIO = 0.0 # Fraction of training for learning rate warmup
DEFAULT_WARMDOWN_RATIO = 0.2 # Fraction of training for learning rate warmdown
DEFAULT_FINAL_LR_FRAC = 0.0 # Final LR as fraction of initial LR
MUON_MOMENTUM_RAMPUP_STEPS = 300 # Steps to ramp up Muon momentum
MUON_MOMENTUM_START = 0.85 # Starting momentum for Muon optimizer
MUON_MOMENTUM_END = 0.95 # Final momentum for Muon optimizer
# Evaluation Constants
LOSS_EMA_BETA = 0.9 # Exponential moving average decay for training loss
WARMUP_IGNORE_STEPS = 10 # Ignore first N steps when calculating training time
# Calculator Tool Constants
CALCULATOR_TIMEOUT_SECONDS = 3 # Maximum time for calculator evaluation

View File

@ -19,6 +19,7 @@ from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.constants import KV_CACHE_GROWTH_CHUNK, CALCULATOR_TIMEOUT_SECONDS
# -----------------------------------------------------------------------------
# Calculator tool helpers
@ -32,7 +33,7 @@ def timeout(duration, formula):
yield
signal.alarm(0)
def eval_with_timeout(formula, max_time=3):
def eval_with_timeout(formula, max_time=CALCULATOR_TIMEOUT_SECONDS):
try:
with timeout(max_time, formula):
with warnings.catch_warnings():
@ -107,8 +108,11 @@ class KVCache:
t0, t1 = self.pos, self.pos + T_add
# Dynamically grow the cache if needed
if t1 > self.kv_cache.size(4):
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
chunk = KV_CACHE_GROWTH_CHUNK
assert chunk > 0 and (chunk & (chunk - 1)) == 0, \
"KV_CACHE_GROWTH_CHUNK must be a positive power of two"
t_needed = t1 + chunk # as much as we need plus buffer
t_needed = (t_needed + chunk - 1) & ~(chunk - 1) # then round up to the nearest multiple
current_shape = list(self.kv_cache.shape)
current_shape[4] = t_needed
self.kv_cache.resize_(current_shape)

View File

@ -22,6 +22,7 @@ import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
from nanochat.constants import ROTARY_CACHE_MULTIPLIER, LOGIT_SOFTCAP
@dataclass
class GPTConfig:
@ -164,7 +165,7 @@ class GPT(nn.Module):
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, but assert fail if we ever reach that amount.
# In the future we can dynamically grow the cache, for now it's fine.
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
self.rotary_seq_len = config.sequence_len * ROTARY_CACHE_MULTIPLIER # 10X over-compute should be enough, TODO make nicer?
head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
@ -275,19 +276,18 @@ class GPT(nn.Module):
x = norm(x)
# Forward the lm_head (compute logits)
softcap = 15
if targets is not None:
# training mode: compute and return the loss
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = LOGIT_SOFTCAP * torch.tanh(logits / LOGIT_SOFTCAP)
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
else:
# inference mode: compute and return the logits
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = LOGIT_SOFTCAP * torch.tanh(logits / LOGIT_SOFTCAP)
return logits
@torch.inference_mode()

View File

@ -22,6 +22,12 @@ from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from scripts.base_eval import evaluate_model
from nanochat.constants import (
MODEL_ASPECT_RATIO, HEAD_DIM_TARGET, DEFAULT_WARMUP_RATIO,
DEFAULT_WARMDOWN_RATIO, DEFAULT_FINAL_LR_FRAC, LOSS_EMA_BETA,
WARMUP_IGNORE_STEPS, MUON_MOMENTUM_RAMPUP_STEPS,
MUON_MOMENTUM_START, MUON_MOMENTUM_END
)
print_banner()
# -----------------------------------------------------------------------------
@ -73,8 +79,8 @@ print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
model_dim = depth * MODEL_ASPECT_RATIO # aspect ratio (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + HEAD_DIM_TARGET - 1) // HEAD_DIM_TARGET) # ceiling division for target head dim
num_kv_heads = num_heads # 1:1 MQA ratio
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
@ -142,9 +148,9 @@ x, y = next(train_loader) # kick off load of the very first batch of data
# Learning rate scheduler
# TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
warmup_ratio = DEFAULT_WARMUP_RATIO # ratio of iterations for LR warmup
warmdown_ratio = DEFAULT_WARMDOWN_RATIO # ratio of iterations for LR warmdown
final_lr_frac = DEFAULT_FINAL_LR_FRAC # final LR is this fraction of the initial LR
def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations)
@ -158,15 +164,15 @@ def get_lr_multiplier(it):
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
momentum = (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END
return momentum
# -----------------------------------------------------------------------------
# Training loop
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
ema_beta = LOSS_EMA_BETA # EMA decay factor
total_training_time = 0 # total wall-clock time of training
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
@ -288,7 +294,7 @@ for step in range(num_iterations + 1):
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
if step > WARMUP_IGNORE_STEPS:
total_training_time += dt # only count the time after the first WARMUP_IGNORE_STEPS steps
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 100 == 0:

View File

@ -27,6 +27,10 @@ from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from nanochat.constants import (
LOSS_EMA_BETA, WARMUP_IGNORE_STEPS,
MUON_MOMENTUM_RAMPUP_STEPS, MUON_MOMENTUM_START, MUON_MOMENTUM_END
)
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@ -146,8 +150,8 @@ def get_lr_multiplier(progress):
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
momentum = (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END
return momentum
# -----------------------------------------------------------------------------
@ -155,7 +159,7 @@ def get_muon_momentum(it):
x, y = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
ema_beta = LOSS_EMA_BETA # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
@ -252,7 +256,7 @@ while True:
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
if step > WARMUP_IGNORE_STEPS:
total_training_time += dt # only count the time after the first WARMUP_IGNORE_STEPS steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0: