From d366f7e07de48be25f544cd48015f8a5a52616f2 Mon Sep 17 00:00:00 2001
From: dssjon
Date: Wed, 15 Oct 2025 14:58:06 -0700
Subject: [PATCH] Refactor constants in training scripts and engine to improve
 configurability.

Replace hardcoded values with named constants for KV cache growth, rotary
cache sizing, logit softcapping, model sizing, and learning rate / momentum
schedule parameters. This improves maintainability and makes future
adjustments easier.
---
 nanochat/constants.py | 28 ++++++++++++++++++++++++++++
 nanochat/engine.py    | 10 +++++++---
 nanochat/gpt.py       |  8 ++++----
 scripts/base_train.py | 24 +++++++++++++++---------
 scripts/mid_train.py  | 12 ++++++++----
 5 files changed, 62 insertions(+), 20 deletions(-)
 create mode 100644 nanochat/constants.py

diff --git a/nanochat/constants.py b/nanochat/constants.py
new file mode 100644
index 0000000..27885e8
--- /dev/null
+++ b/nanochat/constants.py
@@ -0,0 +1,28 @@
+"""
+Constants used throughout the nanochat codebase.
+Centralizes magic numbers for better maintainability.
+"""
+
+# Model Architecture Constants
+MODEL_ASPECT_RATIO = 64  # model_dim = depth * aspect_ratio
+HEAD_DIM_TARGET = 128  # Target dimension per attention head
+ROTARY_CACHE_MULTIPLIER = 10  # Precompute rotary embeddings for N times the sequence length
+LOGIT_SOFTCAP = 15  # Soft-capping value for logits using tanh (must be > 0)
+
+# Memory Management
+KV_CACHE_GROWTH_CHUNK = 1024  # Grow the KV cache in chunks of this size (must be a power of 2 for efficient bitwise rounding)
+
+# Training Constants
+DEFAULT_WARMUP_RATIO = 0.0  # Fraction of training for learning rate warmup
+DEFAULT_WARMDOWN_RATIO = 0.2  # Fraction of training for learning rate warmdown
+DEFAULT_FINAL_LR_FRAC = 0.0  # Final LR as a fraction of the initial LR
+MUON_MOMENTUM_RAMPUP_STEPS = 300  # Steps to ramp up Muon momentum
+MUON_MOMENTUM_START = 0.85  # Starting momentum for the Muon optimizer
+MUON_MOMENTUM_END = 0.95  # Final momentum for the Muon optimizer
+
+# Evaluation Constants
+LOSS_EMA_BETA = 0.9  # Exponential moving average decay for the training loss
+WARMUP_IGNORE_STEPS = 10  # Ignore the first N steps when accumulating training time
+
+# Calculator Tool Constants
+CALCULATOR_TIMEOUT_SECONDS = 3  # Maximum time for calculator evaluation

diff --git a/nanochat/engine.py b/nanochat/engine.py
index de1253a..58bfe39 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -19,6 +19,7 @@ from contextlib import contextmanager
 from collections import deque
 from nanochat.common import compute_init
 from nanochat.checkpoint_manager import load_model
+from nanochat.constants import KV_CACHE_GROWTH_CHUNK, CALCULATOR_TIMEOUT_SECONDS
 
 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@@ -32,7 +33,7 @@ def timeout(duration, formula):
     yield
     signal.alarm(0)
 
-def eval_with_timeout(formula, max_time=3):
+def eval_with_timeout(formula, max_time=CALCULATOR_TIMEOUT_SECONDS):
     try:
         with timeout(max_time, formula):
             with warnings.catch_warnings():
@@ -107,8 +108,11 @@ class KVCache:
         t0, t1 = self.pos, self.pos + T_add
         # Dynamically grow the cache if needed
         if t1 > self.kv_cache.size(4):
-            t_needed = t1 + 1024 # as much as we need plus buffer of 1024
-            t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
+            chunk = KV_CACHE_GROWTH_CHUNK
+            assert chunk > 0 and (chunk & (chunk - 1)) == 0, \
+                "KV_CACHE_GROWTH_CHUNK must be a positive power of two"
+            t_needed = t1 + chunk  # as much as we need plus a buffer of one chunk
+            t_needed = (t_needed + chunk - 1) & ~(chunk - 1)  # then round up to the nearest multiple of the chunk size
             current_shape = list(self.kv_cache.shape)
             current_shape[4] = t_needed
             self.kv_cache.resize_(current_shape)
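
As a quick sanity check of the chunked KV-cache rounding above, here is a standalone sketch (not part of the patch; the chunk constant is inlined and the sample values are illustrative):

# Standalone sketch of the round-up used in KVCache; assumes the chunk is a positive power of two.
def round_up_to_chunk(t_needed: int, chunk: int = 1024) -> int:
    assert chunk > 0 and (chunk & (chunk - 1)) == 0, "chunk must be a positive power of two"
    return (t_needed + chunk - 1) & ~(chunk - 1)

assert round_up_to_chunk(1) == 1024     # even a tiny request grows by a full chunk
assert round_up_to_chunk(1024) == 1024  # exact multiples are left unchanged
assert round_up_to_chunk(1025) == 2048  # crossing a boundary rounds up to the next chunk
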
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 5a066b2..fe3f3e6 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.constants import ROTARY_CACHE_MULTIPLIER, LOGIT_SOFTCAP
 
 @dataclass
 class GPTConfig:
@@ -164,7 +165,7 @@ class GPT(nn.Module):
         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
         # so let's just over-compute them, but assert fail if we ever reach that amount.
         # In the future we can dynamically grow the cache, for now it's fine.
-        self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
+        self.rotary_seq_len = config.sequence_len * ROTARY_CACHE_MULTIPLIER  # over-compute should be enough, TODO make nicer?
         head_dim = config.n_embd // config.n_head
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
@@ -275,19 +276,18 @@ class GPT(nn.Module):
         x = norm(x)
 
         # Forward the lm_head (compute logits)
-        softcap = 15
         if targets is not None:
             # training mode: compute and return the loss
             # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
             logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
+            logits = LOGIT_SOFTCAP * torch.tanh(logits / LOGIT_SOFTCAP)
             logits = logits.float() # use tf32/fp32 for logits
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
             return loss
         else:
             # inference mode: compute and return the logits
             logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
+            logits = LOGIT_SOFTCAP * torch.tanh(logits / LOGIT_SOFTCAP)
             return logits
 
 @torch.inference_mode()
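
For reference, the tanh softcap keeps logits inside (-LOGIT_SOFTCAP, LOGIT_SOFTCAP) while leaving small values nearly untouched; a minimal standalone sketch assuming only torch (sample values are illustrative, not from the patch):

# Minimal sketch of tanh logit soft-capping as used in gpt.py.
import torch

LOGIT_SOFTCAP = 15.0
logits = torch.tensor([0.1, 5.0, 50.0, 500.0])
capped = LOGIT_SOFTCAP * torch.tanh(logits / LOGIT_SOFTCAP)
# Small logits pass through almost unchanged; large ones saturate near the cap.
assert torch.all(capped.abs() <= LOGIT_SOFTCAP)
assert torch.allclose(capped[0], logits[0], atol=1e-3)
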
diff --git a/scripts/base_train.py b/scripts/base_train.py
index b691ed4..b38ac81 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -22,6 +22,12 @@ from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
 from scripts.base_eval import evaluate_model
+from nanochat.constants import (
+    MODEL_ASPECT_RATIO, HEAD_DIM_TARGET, DEFAULT_WARMUP_RATIO,
+    DEFAULT_WARMDOWN_RATIO, DEFAULT_FINAL_LR_FRAC, LOSS_EMA_BETA,
+    WARMUP_IGNORE_STEPS, MUON_MOMENTUM_RAMPUP_STEPS,
+    MUON_MOMENTUM_START, MUON_MOMENTUM_END
+)
 print_banner()
 
 # -----------------------------------------------------------------------------
@@ -73,8 +79,8 @@ print0(f"Vocab size: {vocab_size:,}")
 
 # Model kwargs are derived from the desired depth of the model
 num_layers = depth
-model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
-num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
+model_dim = depth * MODEL_ASPECT_RATIO  # aspect ratio (usually this is varied from 64 -> 128 as model size increases)
+num_heads = max(1, (model_dim + HEAD_DIM_TARGET - 1) // HEAD_DIM_TARGET)  # ceiling division for target head dim
 num_kv_heads = num_heads # 1:1 MQA ratio
 print0(f"num_layers: {num_layers}")
 print0(f"model_dim: {model_dim}")
@@ -142,9 +148,9 @@ x, y = next(train_loader) # kick off load of the very first batch of data
 
 # Learning rate scheduler
 # TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
-warmup_ratio = 0.0 # ratio of iterations for LR warmup
-warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
-final_lr_frac = 0.0 # final LR is this fraction of the initial LR
+warmup_ratio = DEFAULT_WARMUP_RATIO  # ratio of iterations for LR warmup
+warmdown_ratio = DEFAULT_WARMDOWN_RATIO  # ratio of iterations for LR warmdown
+final_lr_frac = DEFAULT_FINAL_LR_FRAC  # final LR is this fraction of the initial LR
 def get_lr_multiplier(it):
     warmup_iters = round(warmup_ratio * num_iterations)
     warmdown_iters = round(warmdown_ratio * num_iterations)
@@ -158,15 +164,15 @@ def get_lr_multiplier(it):
 
 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
-    frac = min(it / 300, 1)
-    momentum = (1 - frac) * 0.85 + frac * 0.95
+    frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
+    momentum = (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END
     return momentum
 
 # -----------------------------------------------------------------------------
 # Training loop
 min_val_bpb = float("inf")
 smooth_train_loss = 0 # EMA of training loss
-ema_beta = 0.9 # EMA decay factor
+ema_beta = LOSS_EMA_BETA  # EMA decay factor
 total_training_time = 0 # total wall-clock time of training
 # note that we run +1 steps only so that we can eval and save at the end
 for step in range(num_iterations + 1):
@@ -288,7 +294,7 @@ for step in range(num_iterations + 1):
     flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
-    if step > 10:
+    if step > WARMUP_IGNORE_STEPS:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 100 == 0:
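
As a worked example of the model-sizing arithmetic above, a sketch under the default constants (the depth value is only illustrative, not from the patch):

# Illustrative sizing sketch using the defaults from nanochat/constants.py (depth chosen arbitrarily).
MODEL_ASPECT_RATIO = 64
HEAD_DIM_TARGET = 128

depth = 20
model_dim = depth * MODEL_ASPECT_RATIO                                    # 20 * 64 = 1280
num_heads = max(1, (model_dim + HEAD_DIM_TARGET - 1) // HEAD_DIM_TARGET)  # ceil(1280 / 128) = 10
assert (model_dim, num_heads) == (1280, 10)
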
diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index 90ab954..3373c1a 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -27,6 +27,10 @@ from tasks.common import TaskMixture
 from tasks.gsm8k import GSM8K
 from tasks.mmlu import MMLU
 from tasks.smoltalk import SmolTalk
+from nanochat.constants import (
+    LOSS_EMA_BETA, WARMUP_IGNORE_STEPS,
+    MUON_MOMENTUM_RAMPUP_STEPS, MUON_MOMENTUM_START, MUON_MOMENTUM_END
+)
 
 # -----------------------------------------------------------------------------
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@@ -146,8 +150,8 @@ def get_lr_multiplier(progress):
 
 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
-    frac = min(it / 300, 1)
-    momentum = (1 - frac) * 0.85 + frac * 0.95
+    frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
+    momentum = (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END
     return momentum
 
 # -----------------------------------------------------------------------------
@@ -155,7 +159,7 @@ def get_muon_momentum(it):
 x, y = next(train_loader) # prefetch the very first batch of data
 min_val_bpb = float("inf")
 smooth_train_loss = 0 # EMA of training loss
-ema_beta = 0.9 # EMA decay factor
+ema_beta = LOSS_EMA_BETA  # EMA decay factor
 total_training_time = 0 # total wall-clock time of training
 step = 0
 while True:
@@ -252,7 +256,7 @@ while True:
     flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
-    if step > 10:
+    if step > WARMUP_IGNORE_STEPS:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 10 == 0:
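
Both training scripts now read the same Muon momentum ramp from nanochat/constants.py; a standalone sketch of its endpoints (constants inlined for illustration, not part of the patch):

# Standalone sketch of the Muon momentum ramp shared by base_train.py and mid_train.py.
MUON_MOMENTUM_RAMPUP_STEPS = 300
MUON_MOMENTUM_START = 0.85
MUON_MOMENTUM_END = 0.95

def get_muon_momentum(it: int) -> float:
    frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
    return (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END

assert get_muon_momentum(0) == 0.85               # start of training
assert abs(get_muon_momentum(150) - 0.90) < 1e-9  # halfway through the ramp
assert get_muon_momentum(300) == 0.95             # ramp complete
assert get_muon_momentum(10_000) == 0.95          # stays at the final value afterwards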