diff --git a/nanochat/constants.py b/nanochat/constants.py
new file mode 100644
index 0000000..27885e8
--- /dev/null
+++ b/nanochat/constants.py
@@ -0,0 +1,28 @@
+"""
+Constants used throughout the nanochat codebase.
+Centralizes magic numbers for better maintainability.
+"""
+
+# Model Architecture Constants
+MODEL_ASPECT_RATIO = 64 # model_dim = depth * aspect_ratio
+HEAD_DIM_TARGET = 128 # Target dimension per attention head
+ROTARY_CACHE_MULTIPLIER = 10 # Precompute rotary embeddings for N times sequence length
+LOGIT_SOFTCAP = 15 # Soft capping value for logits using tanh (must be > 0)
+
+# Memory Management
+KV_CACHE_GROWTH_CHUNK = 1024 # Grow KV cache in chunks of this size (must be power of 2 for efficient bitwise rounding)
+
+# Training Constants
+DEFAULT_WARMUP_RATIO = 0.0 # Fraction of training for learning rate warmup
+DEFAULT_WARMDOWN_RATIO = 0.2 # Fraction of training for learning rate warmdown
+DEFAULT_FINAL_LR_FRAC = 0.0 # Final LR as fraction of initial LR
+MUON_MOMENTUM_RAMPUP_STEPS = 300 # Steps to ramp up Muon momentum
+MUON_MOMENTUM_START = 0.85 # Starting momentum for Muon optimizer
+MUON_MOMENTUM_END = 0.95 # Final momentum for Muon optimizer
+
+# Evaluation Constants
+LOSS_EMA_BETA = 0.9 # Exponential moving average decay for training loss
+WARMUP_IGNORE_STEPS = 10 # Ignore first N steps when calculating training time
+
+# Calculator Tool Constants
+CALCULATOR_TIMEOUT_SECONDS = 3 # Maximum time for calculator evaluation
diff --git a/nanochat/engine.py b/nanochat/engine.py
index de1253a..58bfe39 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -19,6 +19,7 @@ from contextlib import contextmanager
 from collections import deque
 from nanochat.common import compute_init
 from nanochat.checkpoint_manager import load_model
+from nanochat.constants import KV_CACHE_GROWTH_CHUNK, CALCULATOR_TIMEOUT_SECONDS
 
 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@@ -32,7 +33,7 @@ def timeout(duration, formula):
     yield
     signal.alarm(0)
 
-def eval_with_timeout(formula, max_time=3):
+def eval_with_timeout(formula, max_time=CALCULATOR_TIMEOUT_SECONDS):
     try:
         with timeout(max_time, formula):
             with warnings.catch_warnings():
@@ -107,8 +108,11 @@ class KVCache:
         t0, t1 = self.pos, self.pos + T_add
         # Dynamically grow the cache if needed
         if t1 > self.kv_cache.size(4):
-            t_needed = t1 + 1024 # as much as we need plus buffer of 1024
-            t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
+            chunk = KV_CACHE_GROWTH_CHUNK
+            assert chunk > 0 and (chunk & (chunk - 1)) == 0, \
+                "KV_CACHE_GROWTH_CHUNK must be a positive power of two"
+            t_needed = t1 + chunk # as much as we need plus buffer
+            t_needed = (t_needed + chunk - 1) & ~(chunk - 1) # then round up to the nearest multiple
             current_shape = list(self.kv_cache.shape)
             current_shape[4] = t_needed
             self.kv_cache.resize_(current_shape)
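
The bitwise round-up in the KVCache hunk above only works when the chunk size is a power of two, which is exactly what the new assert guards. A minimal standalone sketch of the arithmetic (the round_up helper below is hypothetical and for illustration only; nanochat inlines the expression):

# Sketch of the round-up-to-chunk arithmetic used in the KVCache hunk above.
def round_up(n: int, chunk: int) -> int:
    assert chunk > 0 and (chunk & (chunk - 1)) == 0, "chunk must be a positive power of two"
    # Adding chunk-1 and then masking off the low bits rounds n up to the next multiple of chunk.
    return (n + chunk - 1) & ~(chunk - 1)

if __name__ == "__main__":
    assert round_up(1, 1024) == 1024
    assert round_up(1024, 1024) == 1024  # already-aligned values are unchanged
    assert round_up(1025, 1024) == 2048
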
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 5a066b2..fe3f3e6 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.constants import ROTARY_CACHE_MULTIPLIER, LOGIT_SOFTCAP
 
 @dataclass
 class GPTConfig:
@@ -164,7 +165,7 @@ class GPT(nn.Module):
         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
         # so let's just over-compute them, but assert fail if we ever reach that amount.
         # In the future we can dynamically grow the cache, for now it's fine.
-        self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
+        self.rotary_seq_len = config.sequence_len * ROTARY_CACHE_MULTIPLIER # 10X over-compute should be enough, TODO make nicer?
         head_dim = config.n_embd // config.n_head
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
@@ -275,19 +276,18 @@ class GPT(nn.Module):
         x = norm(x)
 
         # Forward the lm_head (compute logits)
-        softcap = 15
         if targets is not None:
             # training mode: compute and return the loss
             # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
             logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
+            logits = LOGIT_SOFTCAP * torch.tanh(logits / LOGIT_SOFTCAP)
             logits = logits.float() # use tf32/fp32 for logits
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
             return loss
         else:
             # inference mode: compute and return the logits
             logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
+            logits = LOGIT_SOFTCAP * torch.tanh(logits / LOGIT_SOFTCAP)
             return logits
 
     @torch.inference_mode()
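
For reference, the tanh soft-cap that LOGIT_SOFTCAP now names squashes logits smoothly into (-15, 15) while leaving small values nearly untouched. A minimal sketch assuming only PyTorch (the softcap wrapper below is illustrative, not a nanochat function):

import torch

LOGIT_SOFTCAP = 15.0

def softcap(logits: torch.Tensor, cap: float = LOGIT_SOFTCAP) -> torch.Tensor:
    # Smoothly bounds logits to (-cap, cap); approximately the identity when |logits| << cap.
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(softcap(x))  # roughly [-15.0, -1.0, 0.0, 1.0, 15.0]: extremes saturate at the cap, small logits pass through
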
diff --git a/scripts/base_train.py b/scripts/base_train.py
index b691ed4..b38ac81 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -22,6 +22,12 @@ from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
 from scripts.base_eval import evaluate_model
+from nanochat.constants import (
+    MODEL_ASPECT_RATIO, HEAD_DIM_TARGET, DEFAULT_WARMUP_RATIO,
+    DEFAULT_WARMDOWN_RATIO, DEFAULT_FINAL_LR_FRAC, LOSS_EMA_BETA,
+    WARMUP_IGNORE_STEPS, MUON_MOMENTUM_RAMPUP_STEPS,
+    MUON_MOMENTUM_START, MUON_MOMENTUM_END
+)
 
 print_banner()
 # -----------------------------------------------------------------------------
@@ -73,8 +79,8 @@ print0(f"Vocab size: {vocab_size:,}")
 
 # Model kwargs are derived from the desired depth of the model
 num_layers = depth
-model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
-num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
+model_dim = depth * MODEL_ASPECT_RATIO # aspect ratio (usually this is varied from 64 -> 128 as model size increases)
+num_heads = max(1, (model_dim + HEAD_DIM_TARGET - 1) // HEAD_DIM_TARGET) # ceiling division for target head dim
 num_kv_heads = num_heads # 1:1 MQA ratio
 print0(f"num_layers: {num_layers}")
 print0(f"model_dim: {model_dim}")
@@ -142,9 +148,9 @@ x, y = next(train_loader) # kick off load of the very first batch of data
 
 # Learning rate scheduler
 # TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
-warmup_ratio = 0.0 # ratio of iterations for LR warmup
-warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
-final_lr_frac = 0.0 # final LR is this fraction of the initial LR
+warmup_ratio = DEFAULT_WARMUP_RATIO # ratio of iterations for LR warmup
+warmdown_ratio = DEFAULT_WARMDOWN_RATIO # ratio of iterations for LR warmdown
+final_lr_frac = DEFAULT_FINAL_LR_FRAC # final LR is this fraction of the initial LR
 def get_lr_multiplier(it):
     warmup_iters = round(warmup_ratio * num_iterations)
     warmdown_iters = round(warmdown_ratio * num_iterations)
@@ -158,15 +164,15 @@
 
 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
-    frac = min(it / 300, 1)
-    momentum = (1 - frac) * 0.85 + frac * 0.95
+    frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
+    momentum = (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END
     return momentum
 
 # -----------------------------------------------------------------------------
 # Training loop
 min_val_bpb = float("inf")
 smooth_train_loss = 0 # EMA of training loss
-ema_beta = 0.9 # EMA decay factor
+ema_beta = LOSS_EMA_BETA # EMA decay factor
 total_training_time = 0 # total wall-clock time of training
 # note that we run +1 steps only so that we can eval and save at the end
 for step in range(num_iterations + 1):
@@ -288,7 +294,7 @@ for step in range(num_iterations + 1):
     flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
-    if step > 10:
+    if step > WARMUP_IGNORE_STEPS:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 100 == 0:
diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index 90ab954..3373c1a 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -27,6 +27,10 @@ from tasks.common import TaskMixture
 from tasks.gsm8k import GSM8K
 from tasks.mmlu import MMLU
 from tasks.smoltalk import SmolTalk
+from nanochat.constants import (
+    LOSS_EMA_BETA, WARMUP_IGNORE_STEPS,
+    MUON_MOMENTUM_RAMPUP_STEPS, MUON_MOMENTUM_START, MUON_MOMENTUM_END
+)
 
 # -----------------------------------------------------------------------------
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@@ -146,8 +150,8 @@ def get_lr_multiplier(progress):
 
 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
-    frac = min(it / 300, 1)
-    momentum = (1 - frac) * 0.85 + frac * 0.95
+    frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
+    momentum = (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END
     return momentum
 
 # -----------------------------------------------------------------------------
@@ -155,7 +159,7 @@
 x, y = next(train_loader) # prefetch the very first batch of data
 min_val_bpb = float("inf")
 smooth_train_loss = 0 # EMA of training loss
-ema_beta = 0.9 # EMA decay factor
+ema_beta = LOSS_EMA_BETA # EMA decay factor
 total_training_time = 0 # total wall-clock time of training
 step = 0
 while True:
@@ -252,7 +256,7 @@ while True:
     flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
-    if step > 10:
+    if step > WARMUP_IGNORE_STEPS:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 10 == 0:
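
Taken together, the schedule constants reproduce the existing behavior exactly. A quick standalone sketch of the Muon momentum ramp as defined in base_train.py and mid_train.py above (the driver loop at the bottom is illustrative only):

MUON_MOMENTUM_RAMPUP_STEPS = 300
MUON_MOMENTUM_START = 0.85
MUON_MOMENTUM_END = 0.95

def get_muon_momentum(it):
    # Linearly interpolate from START to END over the first RAMPUP_STEPS iterations, then hold constant.
    frac = min(it / MUON_MOMENTUM_RAMPUP_STEPS, 1)
    return (1 - frac) * MUON_MOMENTUM_START + frac * MUON_MOMENTUM_END

for it in (0, 150, 300, 1000):
    print(it, round(get_muon_momentum(it), 4))  # 0.85, 0.9, 0.95, 0.95
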