From 0c942a8c009e7eb992e652e3b17e7ec2cec97311 Mon Sep 17 00:00:00 2001 From: William Thurston Date: Sun, 22 Feb 2026 14:42:58 -0800 Subject: [PATCH] Add tie_embeddings support and configurable logging interval MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement weight tying between token embeddings and lm_head to reduce parameter count. When enabled, logits are scaled by 1/√d_model, lm_head zeroing is skipped, and optimizer groups are deduplicated. Param counting uses unique parameters while Chinchilla ratio calculation adds back the would-be lm_head size for comparability. Also adds boolean flag parsing (--flag without =value) to the configurator, an auto-derived log_every interval, and minor shell script fixes. The eval_every default changes from 250 to 0 (auto, ~1% of training steps), matching the new log_every default. Co-Authored-By: Claude Opus 4.6 --- dev/runmps.sh | 7 ++++-- nanochat/configurator.py | 12 +++++++++- nanochat/gpt.py | 28 ++++++++++++++++++---- scripts/base_train.py | 51 +++++++++++++++++++++++++++++++++------- 4 files changed, 82 insertions(+), 16 deletions(-) diff --git a/dev/runmps.sh b/dev/runmps.sh index 4831c71..414d77d 100755 --- a/dev/runmps.sh +++ b/dev/runmps.sh @@ -276,6 +276,9 @@ python -m scripts.tok_eval MOE_FLAGS+=("--moe_activation_denominator=$MOE_ACTIVATION_DEN") fi fi + if [ "${TIE_EMBEDDINGS:-0}" -ne 0 ]; then + MOE_FLAGS+=("--tie_embeddings=1") + fi MOE_DEBUG_INTERVAL=$MOE_DEBUG_INTERVAL python -m scripts.base_train \ --depth=$BASE_DEPTH \ @@ -298,8 +301,8 @@ python -m scripts.tok_eval fi if [ "$RUN_STAGE_EVALS" = "1" ]; then - python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS $BASE_MODEL_TAG_FLAG - python -m scripts.base_eval --max-per-task=16 $BASE_MODEL_TAG_FLAG_HYPHEN + python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS $BASE_MODEL_TAG_FLAG + python -m scripts.base_eval --max-per-task=16 "$BASE_MODEL_TAG_FLAG_HYPHEN" fi fi diff --git a/nanochat/configurator.py b/nanochat/configurator.py index 
63d197f..0c5b1fe 100644 --- a/nanochat/configurator.py +++ b/nanochat/configurator.py @@ -25,8 +25,16 @@ def print0(s="",**kwargs): for arg in sys.argv[1:]: if '=' not in arg: + if arg.startswith('--'): + # boolean flag form: --flag sets True if default exists and is bool + key = arg[2:] + if key in globals() and isinstance(globals()[key], bool): + print0(f"Overriding: {key} = True") + globals()[key] = True + else: + raise ValueError(f"Unknown or non-bool config key for flag: {key}") + continue # assume it's the name of a config file - assert not arg.startswith('--') config_file = arg print0(f"Overriding config with {config_file}:") with open(config_file) as f: @@ -50,6 +58,8 @@ for arg in sys.argv[1:]: default_type = type(globals()[key]) if default_type is float and attempt_type in (int, float): attempt = float(attempt) + elif default_type is bool and attempt_type in (int, bool): + attempt = bool(attempt) else: assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}" # cross fingers diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 3754968..b1fe296 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -38,6 +38,7 @@ class GPTConfig: dense_layers_before_moe: int = 0 moe_granularity_target: float = 12.0 moe_activation_denominator: int = 32 + tie_embeddings: bool = False def norm(x): @@ -361,6 +362,9 @@ class GPT(nn.Module): "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), }) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + if config.tie_embeddings: + # share weights instead of keeping a duplicate parameter + self.lm_head.weight = self.transformer.wte.weight # To support meta device initialization, we init the rotary embeddings here, but it's fake # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them, but assert fail if we ever reach that amount. 
@@ -373,8 +377,9 @@ class GPT(nn.Module): def init_weights(self): self.apply(self._init_weights) - # zero out classifier weights - torch.nn.init.zeros_(self.lm_head.weight) + # zero out classifier weights when untied + if not self.config.tie_embeddings: + torch.nn.init.zeros_(self.lm_head.weight) # zero out c_proj weights in all blocks for block in self.transformer.h: if isinstance(block.mlp, MoEFeedForward): @@ -488,17 +493,25 @@ class GPT(nn.Module): # Separate out all parameters into 3 groups (matrix, embedding, lm_head) matrix_params = list(self.transformer.h.parameters()) embedding_params = list(self.transformer.wte.parameters()) - lm_head_params = list(self.lm_head.parameters()) - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + # When tied, lm_head shares the embedding weight, so keep its param out of optimizer groups to avoid duplicates. + lm_head_params = [] if self.config.tie_embeddings else list(self.lm_head.parameters()) + # Sanity check: all parameters should be covered by the optimizer groups (ignoring intentional sharing). + if not self.config.tie_embeddings: + param_ids = {id(p) for p in self.parameters()} + assert param_ids == {id(p) for p in matrix_params + embedding_params + lm_head_params} # Create the AdamW optimizer for the embedding and lm_head # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 if rank == 0: print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") + # Use the embedding LR for embeddings (even when tied); head params use unembedding LR when present. 
+ embedding_lr_effective = embedding_lr adam_groups = [ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), - dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), + dict(params=embedding_params, lr=embedding_lr_effective * dmodel_lr_scale), ] + # drop empty groups to avoid optimizer complaints when embeddings are tied + adam_groups = [g for g in adam_groups if g["params"]] adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay) AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True) adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs) @@ -533,10 +546,13 @@ class GPT(nn.Module): # Forward the lm_head (compute logits) softcap = 15 + logit_scale = (self.config.n_embd ** -0.5) if self.config.tie_embeddings else 1.0 if targets is not None: # training mode: compute and return the loss # TODO: experiment with Liger Kernels / chunked cross-entropy etc. logits = self.lm_head(x) + if logit_scale != 1.0: + logits = logits * logit_scale logits = softcap * torch.tanh(logits / softcap) # logits softcap logits = logits.float() # use tf32/fp32 for logits loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction) @@ -544,6 +560,8 @@ class GPT(nn.Module): else: # inference mode: compute and return the logits logits = self.lm_head(x) + if logit_scale != 1.0: + logits = logits * logit_scale logits = softcap * torch.tanh(logits / softcap) # logits softcap return logits diff --git a/scripts/base_train.py b/scripts/base_train.py index eb2dd49..28b20f3 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -11,6 +11,7 @@ If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. 
Ex python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20 """ +import math import os os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import time @@ -45,6 +46,7 @@ moe_expert_ffn_mult = -1.0 # -1 => derive from granularity target (defaults to 1 dense_layers_before_moe = -1 # -1 => derive (≈10% of layers, min 1) before switching to MoE moe_granularity_target = 12.0 # Ling guidance: target granularity per layer (2*d_model/d_expert) moe_activation_denominator = 32 # derive top-k as num_experts / denominator (~3% activation) +tie_embeddings = False # tie output head to token embeddings (saves params) # Training horizon. Only one of these 3 will be used, in this order of precedence. num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable) @@ -61,11 +63,12 @@ warmup_ratio = 0.0 # ratio of iterations for LR warmup warmdown_ratio = 0.2 # ratio of iterations for LR warmdown final_lr_frac = 0.0 # final LR is this fraction of the initial LR # Evaluation -eval_every = 250 # every how many steps to evaluate the model for val bpb +eval_every = 0 # <=0 => auto (~1% of total training steps) evaluation cadence eval_tokens = 20*524288 # number of tokens to evaluate val loss on core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable) core_metric_max_per_task = 500 # examples per task in estimating the core metric sample_every = 2000 # every how many steps to sample from the model +log_every = 0 # <=0 => auto (~1% of total training steps) logging cadence for train/MoE metrics # Output model_tag = "" # optionally override the model tag for the output checkpoint directory name checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable) @@ -192,24 
+195,47 @@ model_config_kwargs = dict( dense_layers_before_moe=dense_layers_before_moe, moe_granularity_target=granularity_target, moe_activation_denominator=activation_denom, + tie_embeddings=tie_embeddings, ) with torch.device("meta"): model_config = GPTConfig(**model_config_kwargs) model = GPT(model_config) model.to_empty(device=device) model.init_weights() +# Re-assert tying at runtime (belt-and-suspenders) and emit a sanity check. +if tie_embeddings: + model.lm_head.weight = model.transformer.wte.weight + assert model.lm_head.weight is model.transformer.wte.weight, "tie_embeddings requested but weights not shared" orig_model = model # original, uncompiled model, for saving raw model state_dict dense_like_flops = model.estimate_flops() active_flops_per_token, dense_ref_flops = model.estimate_moe_active_flops() num_flops_per_token = active_flops_per_token model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through -num_params = sum(p.numel() for p in model.parameters()) + +def _count_unique_params(m): + seen = set() + total = 0 + for p in m.parameters(): + pid = id(p) + if pid in seen: + continue + seen.add(pid) + total += p.numel() + return total + +num_params = _count_unique_params(model) print0(f"Number of parameters: {num_params:,}") print0(f"Estimated FLOPs per token (dense-like): {dense_like_flops:e}") if active_flops_per_token != dense_like_flops: print0(f"Estimated FLOPs per token (MoE active): {active_flops_per_token:e}") print0(f"Estimated FLOPs per token (dense reference): {dense_ref_flops:e}") +# If embeddings are tied, keep the data:parameter ratio comparable to untied models by +# counting the would-be lm_head parameters toward the ratio calculation. 
+num_params_for_ratio = _count_unique_params(model) +if tie_embeddings: + num_params_for_ratio += vocab_size * model_dim + user_config.update({ "moe_num_experts": moe_num_experts, "moe_num_shared_experts": moe_num_shared_experts, @@ -246,18 +272,27 @@ elif target_flops > 0: print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") elif target_param_data_ratio > 0: # calculate the number of iterations from the target param data ratio - target_tokens = target_param_data_ratio * num_params + target_tokens = target_param_data_ratio * num_params_for_ratio num_iterations = target_tokens // total_batch_size print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") else: raise ValueError("No training horizon specified") total_tokens = total_batch_size * num_iterations print0(f"Total number of training tokens: {total_tokens:,}") -print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20 +print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params_for_ratio:.2f}") # Chinchilla is ~20 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") -if eval_every <= 0: - eval_every = max(1, num_iterations // 100) - print0(f"Auto-setting eval_every to {eval_every} (~1% of training)") +def _resolve_progress_interval(name, configured_value, total_steps, default_frac): + if configured_value > 0: + return max(1, int(configured_value)) + interval = max(1, math.ceil(total_steps * default_frac)) + print0(f"Auto-setting {name} to {interval} (~{default_frac * 100:.2f}% of training)") + return interval + +eval_every = _resolve_progress_interval("eval_every", eval_every, num_iterations, 0.01) +user_config["eval_every"] = eval_every + +log_every_steps = _resolve_progress_interval("log_every", log_every, num_iterations, 0.01) +user_config["log_every"] = log_every_steps sequences_per_step = max(1, total_batch_size // max_seq_len) 
checkpoint_every_steps = int(checkpoint_every_steps) @@ -477,7 +512,7 @@ for step in range(num_iterations + 1): if step > 10: total_training_time += dt # only count the time after the first 10 steps print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") - if step % 100 == 0: + if step % log_every_steps == 0: log_payload = { "step": step, "total_training_flops": flops_so_far,