Add tie_embeddings support and configurable logging interval

Implement weight tying between the token embeddings and lm_head to reduce
the parameter count. When tying is enabled, logits are scaled by 1/√d_model,
the zero-initialization of lm_head is skipped, and the shared weight enters
the optimizer groups only once. Parameter counting uses unique parameters,
while the Chinchilla-style data:param ratio adds back the would-be lm_head
size so tied runs stay comparable to untied ones.
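
In rough PyTorch terms, tying behaves like the sketch below (illustrative
only; the sizes, tensor, and names are stand-ins, not the actual GPT module):

    import torch
    import torch.nn as nn

    d_model, vocab_size = 768, 50304
    wte = nn.Embedding(vocab_size, d_model)
    lm_head = nn.Linear(d_model, vocab_size, bias=False)
    lm_head.weight = wte.weight                 # tie: one shared parameter
    x = torch.randn(2, 8, d_model)              # stand-in hidden states
    logits = lm_head(x) * d_model ** -0.5       # 1/sqrt(d_model) scaling when tied

    # Unique-parameter counting sees the shared weight once; the ratio math
    # adds vocab_size * d_model back so tied and untied runs stay comparable.
    unique = {id(p): p.numel() for p in [*wte.parameters(), *lm_head.parameters()]}
    assert sum(unique.values()) == vocab_size * d_model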

Also add boolean flag parsing (--flag with no =value) to the configurator,
an auto-derived log_every interval, and minor shell script quoting fixes.
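
Any cadence left at <= 0 resolves to roughly 1% of total training steps; a
minimal sketch of that derivation, mirroring the helper added in base_train
(the step counts here are made up):

    import math

    def resolve_progress_interval(configured_value, total_steps, default_frac=0.01):
        if configured_value > 0:
            return max(1, int(configured_value))
        return max(1, math.ceil(total_steps * default_frac))

    print(resolve_progress_interval(0, 5000))    # -> 50, auto (~1% of training)
    print(resolve_progress_interval(250, 5000))  # -> 250, explicit value wins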

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
William Thurston 2026-02-22 14:42:58 -08:00
parent 5f13389568
commit 0c942a8c00
4 changed files with 82 additions and 16 deletions

@@ -276,6 +276,9 @@ python -m scripts.tok_eval
MOE_FLAGS+=("--moe_activation_denominator=$MOE_ACTIVATION_DEN")
fi
fi
if [ "${TIE_EMBEDDINGS:-0}" -ne 0 ]; then
MOE_FLAGS+=("--tie_embeddings=1")
fi
MOE_DEBUG_INTERVAL=$MOE_DEBUG_INTERVAL python -m scripts.base_train \
--depth=$BASE_DEPTH \
@@ -298,8 +301,8 @@ python -m scripts.tok_eval
fi
if [ "$RUN_STAGE_EVALS" = "1" ]; then
python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS $BASE_MODEL_TAG_FLAG
python -m scripts.base_eval --max-per-task=16 $BASE_MODEL_TAG_FLAG_HYPHEN
python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS $BASE_MODEL_TAG_FLAG
python -m scripts.base_eval --max-per-task=16 "$BASE_MODEL_TAG_FLAG_HYPHEN"
fi
fi

@@ -25,8 +25,16 @@ def print0(s="",**kwargs):
for arg in sys.argv[1:]:
if '=' not in arg:
if arg.startswith('--'):
# boolean flag form: --flag sets True if default exists and is bool
key = arg[2:]
if key in globals() and isinstance(globals()[key], bool):
print0(f"Overriding: {key} = True")
globals()[key] = True
else:
raise ValueError(f"Unknown or non-bool config key for flag: {key}")
continue
# assume it's the name of a config file
assert not arg.startswith('--')
config_file = arg
print0(f"Overriding config with {config_file}:")
with open(config_file) as f:
@@ -50,6 +58,8 @@ for arg in sys.argv[1:]:
default_type = type(globals()[key])
if default_type is float and attempt_type in (int, float):
attempt = float(attempt)
elif default_type is bool and attempt_type in (int, bool):
attempt = bool(attempt)
else:
assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
# cross fingers

@@ -38,6 +38,7 @@ class GPTConfig:
dense_layers_before_moe: int = 0
moe_granularity_target: float = 12.0
moe_activation_denominator: int = 32
tie_embeddings: bool = False
def norm(x):
@@ -361,6 +362,9 @@ class GPT(nn.Module):
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
if config.tie_embeddings:
# share weights instead of keeping a duplicate parameter
self.lm_head.weight = self.transformer.wte.weight
# To support meta device initialization, we init the rotary embeddings here, but it's fake
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, but assert fail if we ever reach that amount.
@@ -373,8 +377,9 @@ class GPT(nn.Module):
def init_weights(self):
self.apply(self._init_weights)
# zero out classifier weights
torch.nn.init.zeros_(self.lm_head.weight)
# zero out classifier weights when untied
if not self.config.tie_embeddings:
torch.nn.init.zeros_(self.lm_head.weight)
# zero out c_proj weights in all blocks
for block in self.transformer.h:
if isinstance(block.mlp, MoEFeedForward):
@@ -488,17 +493,25 @@ class GPT(nn.Module):
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
matrix_params = list(self.transformer.h.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
# When tied, lm_head shares the embedding weight, so keep its param out of optimizer groups to avoid duplicates.
lm_head_params = [] if self.config.tie_embeddings else list(self.lm_head.parameters())
# Sanity check: all parameters should be covered by the optimizer groups (ignoring intentional sharing).
if not self.config.tie_embeddings:
param_ids = {id(p) for p in self.parameters()}
assert param_ids == {id(p) for p in matrix_params + embedding_params + lm_head_params}
# Create the AdamW optimizer for the embedding and lm_head
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
if rank == 0:
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
# Use the embedding LR for embeddings (even when tied); head params use unembedding LR when present.
embedding_lr_effective = embedding_lr
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr_effective * dmodel_lr_scale),
]
# drop empty groups to avoid optimizer complaints when embeddings are tied
adam_groups = [g for g in adam_groups if g["params"]]
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
@@ -533,10 +546,13 @@ class GPT(nn.Module):
# Forward the lm_head (compute logits)
softcap = 15
logit_scale = (self.config.n_embd ** -0.5) if self.config.tie_embeddings else 1.0
if targets is not None:
# training mode: compute and return the loss
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
logits = self.lm_head(x)
if logit_scale != 1.0:
logits = logits * logit_scale
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
@@ -544,6 +560,8 @@ class GPT(nn.Module):
else:
# inference mode: compute and return the logits
logits = self.lm_head(x)
if logit_scale != 1.0:
logits = logits * logit_scale
logits = softcap * torch.tanh(logits / softcap) # logits softcap
return logits

@@ -11,6 +11,7 @@ If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Ex
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
"""
import math
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
@@ -45,6 +46,7 @@ moe_expert_ffn_mult = -1.0 # -1 => derive from granularity target (defaults to 1
dense_layers_before_moe = -1 # -1 => derive (≈10% of layers, min 1) before switching to MoE
moe_granularity_target = 12.0 # Ling guidance: target granularity per layer (2*d_model/d_expert)
moe_activation_denominator = 32 # derive top-k as num_experts / denominator (~3% activation)
tie_embeddings = False # tie output head to token embeddings (saves params)
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
@@ -61,11 +63,12 @@ warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_every = 0 # <=0 => auto (~1% of total training steps) evaluation cadence
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
log_every = 0 # <=0 => auto (~1% of total training steps) logging cadence for train/MoE metrics
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
@@ -192,24 +195,47 @@ model_config_kwargs = dict(
dense_layers_before_moe=dense_layers_before_moe,
moe_granularity_target=granularity_target,
moe_activation_denominator=activation_denom,
tie_embeddings=tie_embeddings,
)
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
# Re-assert tying at runtime (belt-and-suspenders) and emit a sanity check.
if tie_embeddings:
model.lm_head.weight = model.transformer.wte.weight
assert model.lm_head.weight is model.transformer.wte.weight, "tie_embeddings requested but weights not shared"
orig_model = model # original, uncompiled model, for saving raw model state_dict
dense_like_flops = model.estimate_flops()
active_flops_per_token, dense_ref_flops = model.estimate_moe_active_flops()
num_flops_per_token = active_flops_per_token
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
num_params = sum(p.numel() for p in model.parameters())
def _count_unique_params(m):
seen = set()
total = 0
for p in m.parameters():
pid = id(p)
if pid in seen:
continue
seen.add(pid)
total += p.numel()
return total
num_params = _count_unique_params(model)
print0(f"Number of parameters: {num_params:,}")
print0(f"Estimated FLOPs per token (dense-like): {dense_like_flops:e}")
if active_flops_per_token != dense_like_flops:
print0(f"Estimated FLOPs per token (MoE active): {active_flops_per_token:e}")
print0(f"Estimated FLOPs per token (dense reference): {dense_ref_flops:e}")
# If embeddings are tied, keep the data:parameter ratio comparable to untied models by
# counting the would-be lm_head parameters toward the ratio calculation.
num_params_for_ratio = _count_unique_params(model)
if tie_embeddings:
num_params_for_ratio += vocab_size * model_dim
user_config.update({
"moe_num_experts": moe_num_experts,
"moe_num_shared_experts": moe_num_shared_experts,
@@ -246,18 +272,27 @@ elif target_flops > 0:
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio
target_tokens = target_param_data_ratio * num_params
target_tokens = target_param_data_ratio * num_params_for_ratio
num_iterations = target_tokens // total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params_for_ratio:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
if eval_every <= 0:
eval_every = max(1, num_iterations // 100)
print0(f"Auto-setting eval_every to {eval_every} (~1% of training)")
def _resolve_progress_interval(name, configured_value, total_steps, default_frac):
if configured_value > 0:
return max(1, int(configured_value))
interval = max(1, math.ceil(total_steps * default_frac))
print0(f"Auto-setting {name} to {interval} (~{default_frac * 100:.2f}% of training)")
return interval
eval_every = _resolve_progress_interval("eval_every", eval_every, num_iterations, 0.01)
user_config["eval_every"] = eval_every
log_every_steps = _resolve_progress_interval("log_every", log_every, num_iterations, 0.01)
user_config["log_every"] = log_every_steps
sequences_per_step = max(1, total_batch_size // max_seq_len)
checkpoint_every_steps = int(checkpoint_every_steps)
@@ -477,7 +512,7 @@ for step in range(num_iterations + 1):
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 100 == 0:
if step % log_every_steps == 0:
log_payload = {
"step": step,
"total_training_flops": flops_so_far,