mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-20 03:43:20 +00:00
Add tie_embeddings support and configurable logging interval
Implement weight tying between token embeddings and lm_head to reduce parameter count. When enabled, logits are scaled by 1/√d_model, lm_head zeroing is skipped, and optimizer groups are deduplicated. Parameter counting uses unique parameters, while the Chinchilla ratio calculation adds back the would-be lm_head size for comparability with untied models. Also adds boolean flag parsing (--flag without =value) to the configurator, auto-derived eval_every and log_every intervals (~1% of training steps when set to <=0), and minor shell script fixes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
5f13389568
commit
0c942a8c00
|
|
@ -276,6 +276,9 @@ python -m scripts.tok_eval
|
|||
MOE_FLAGS+=("--moe_activation_denominator=$MOE_ACTIVATION_DEN")
|
||||
fi
|
||||
fi
|
||||
if [ "${TIE_EMBEDDINGS:-0}" -ne 0 ]; then
|
||||
MOE_FLAGS+=("--tie_embeddings=1")
|
||||
fi
|
||||
|
||||
MOE_DEBUG_INTERVAL=$MOE_DEBUG_INTERVAL python -m scripts.base_train \
|
||||
--depth=$BASE_DEPTH \
|
||||
|
|
@ -298,8 +301,8 @@ python -m scripts.tok_eval
|
|||
fi
|
||||
|
||||
if [ "$RUN_STAGE_EVALS" = "1" ]; then
|
||||
python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS $BASE_MODEL_TAG_FLAG
|
||||
python -m scripts.base_eval --max-per-task=16 $BASE_MODEL_TAG_FLAG_HYPHEN
|
||||
python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS $BASE_MODEL_TAG_FLAG
|
||||
python -m scripts.base_eval --max-per-task=16 "$BASE_MODEL_TAG_FLAG_HYPHEN"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -25,8 +25,16 @@ def print0(s="",**kwargs):
|
|||
|
||||
for arg in sys.argv[1:]:
|
||||
if '=' not in arg:
|
||||
if arg.startswith('--'):
|
||||
# boolean flag form: --flag sets True if default exists and is bool
|
||||
key = arg[2:]
|
||||
if key in globals() and isinstance(globals()[key], bool):
|
||||
print0(f"Overriding: {key} = True")
|
||||
globals()[key] = True
|
||||
else:
|
||||
raise ValueError(f"Unknown or non-bool config key for flag: {key}")
|
||||
continue
|
||||
# assume it's the name of a config file
|
||||
assert not arg.startswith('--')
|
||||
config_file = arg
|
||||
print0(f"Overriding config with {config_file}:")
|
||||
with open(config_file) as f:
|
||||
|
|
@ -50,6 +58,8 @@ for arg in sys.argv[1:]:
|
|||
default_type = type(globals()[key])
|
||||
if default_type is float and attempt_type in (int, float):
|
||||
attempt = float(attempt)
|
||||
elif default_type is bool and attempt_type in (int, bool):
|
||||
attempt = bool(attempt)
|
||||
else:
|
||||
assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
|
||||
# cross fingers
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class GPTConfig:
|
|||
dense_layers_before_moe: int = 0
|
||||
moe_granularity_target: float = 12.0
|
||||
moe_activation_denominator: int = 32
|
||||
tie_embeddings: bool = False
|
||||
|
||||
|
||||
def norm(x):
|
||||
|
|
@ -361,6 +362,9 @@ class GPT(nn.Module):
|
|||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
})
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
if config.tie_embeddings:
|
||||
# share weights instead of keeping a duplicate parameter
|
||||
self.lm_head.weight = self.transformer.wte.weight
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's fake
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them, but assert fail if we ever reach that amount.
|
||||
|
|
@ -373,8 +377,9 @@ class GPT(nn.Module):
|
|||
|
||||
def init_weights(self):
|
||||
self.apply(self._init_weights)
|
||||
# zero out classifier weights
|
||||
torch.nn.init.zeros_(self.lm_head.weight)
|
||||
# zero out classifier weights when untied
|
||||
if not self.config.tie_embeddings:
|
||||
torch.nn.init.zeros_(self.lm_head.weight)
|
||||
# zero out c_proj weights in all blocks
|
||||
for block in self.transformer.h:
|
||||
if isinstance(block.mlp, MoEFeedForward):
|
||||
|
|
@ -488,17 +493,25 @@ class GPT(nn.Module):
|
|||
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
|
||||
# When tied, lm_head shares the embedding weight, so keep its param out of optimizer groups to avoid duplicates.
|
||||
lm_head_params = [] if self.config.tie_embeddings else list(self.lm_head.parameters())
|
||||
# Sanity check: all parameters should be covered by the optimizer groups (ignoring intentional sharing).
|
||||
if not self.config.tie_embeddings:
|
||||
param_ids = {id(p) for p in self.parameters()}
|
||||
assert param_ids == {id(p) for p in matrix_params + embedding_params + lm_head_params}
|
||||
# Create the AdamW optimizer for the embedding and lm_head
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
if rank == 0:
|
||||
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
# Use the embedding LR for embeddings (even when tied); head params use unembedding LR when present.
|
||||
embedding_lr_effective = embedding_lr
|
||||
adam_groups = [
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr_effective * dmodel_lr_scale),
|
||||
]
|
||||
# drop empty groups to avoid optimizer complaints when embeddings are tied
|
||||
adam_groups = [g for g in adam_groups if g["params"]]
|
||||
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
|
||||
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
||||
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
|
||||
|
|
@ -533,10 +546,13 @@ class GPT(nn.Module):
|
|||
|
||||
# Forward the lm_head (compute logits)
|
||||
softcap = 15
|
||||
logit_scale = (self.config.n_embd ** -0.5) if self.config.tie_embeddings else 1.0
|
||||
if targets is not None:
|
||||
# training mode: compute and return the loss
|
||||
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
|
||||
logits = self.lm_head(x)
|
||||
if logit_scale != 1.0:
|
||||
logits = logits * logit_scale
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
logits = logits.float() # use tf32/fp32 for logits
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
|
|
@ -544,6 +560,8 @@ class GPT(nn.Module):
|
|||
else:
|
||||
# inference mode: compute and return the logits
|
||||
logits = self.lm_head(x)
|
||||
if logit_scale != 1.0:
|
||||
logits = logits * logit_scale
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
return logits
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Ex
|
|||
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
|
|
@ -45,6 +46,7 @@ moe_expert_ffn_mult = -1.0 # -1 => derive from granularity target (defaults to 1
|
|||
dense_layers_before_moe = -1 # -1 => derive (≈10% of layers, min 1) before switching to MoE
|
||||
moe_granularity_target = 12.0 # Ling guidance: target granularity per layer (2*d_model/d_expert)
|
||||
moe_activation_denominator = 32 # derive top-k as num_experts / denominator (~3% activation)
|
||||
tie_embeddings = False # tie output head to token embeddings (saves params)
|
||||
# Training horizon. Only one of these 3 will be used, in this order of precedence.
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
|
|
@ -61,11 +63,12 @@ warmup_ratio = 0.0 # ratio of iterations for LR warmup
|
|||
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
# Evaluation
|
||||
eval_every = 250 # every how many steps to evaluate the model for val bpb
|
||||
eval_every = 0 # <=0 => auto (~1% of total training steps) evaluation cadence
|
||||
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
|
||||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
log_every = 0 # <=0 => auto (~1% of total training steps) logging cadence for train/MoE metrics
|
||||
# Output
|
||||
model_tag = "" # optionally override the model tag for the output checkpoint directory name
|
||||
checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
|
||||
|
|
@ -192,24 +195,47 @@ model_config_kwargs = dict(
|
|||
dense_layers_before_moe=dense_layers_before_moe,
|
||||
moe_granularity_target=granularity_target,
|
||||
moe_activation_denominator=activation_denom,
|
||||
tie_embeddings=tie_embeddings,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device=device)
|
||||
model.init_weights()
|
||||
# Re-assert tying at runtime (belt-and-suspenders) and emit a sanity check.
|
||||
if tie_embeddings:
|
||||
model.lm_head.weight = model.transformer.wte.weight
|
||||
assert model.lm_head.weight is model.transformer.wte.weight, "tie_embeddings requested but weights not shared"
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict
|
||||
dense_like_flops = model.estimate_flops()
|
||||
active_flops_per_token, dense_ref_flops = model.estimate_moe_active_flops()
|
||||
num_flops_per_token = active_flops_per_token
|
||||
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
|
||||
def _count_unique_params(m):
    """Return the total element count of m's parameters, counting shared tensors once.

    Parameters are deduplicated by object identity, so a parameter object that
    m.parameters() yields more than once (e.g. a weight shared between the
    token embedding and lm_head) contributes to the total a single time.
    """
    # Keyed by id(): a repeated parameter object just overwrites its own entry.
    numel_by_identity = {id(param): param.numel() for param in m.parameters()}
    return sum(numel_by_identity.values())
|
||||
|
||||
num_params = _count_unique_params(model)
|
||||
print0(f"Number of parameters: {num_params:,}")
|
||||
print0(f"Estimated FLOPs per token (dense-like): {dense_like_flops:e}")
|
||||
if active_flops_per_token != dense_like_flops:
|
||||
print0(f"Estimated FLOPs per token (MoE active): {active_flops_per_token:e}")
|
||||
print0(f"Estimated FLOPs per token (dense reference): {dense_ref_flops:e}")
|
||||
|
||||
# If embeddings are tied, keep the data:parameter ratio comparable to untied models by
|
||||
# counting the would-be lm_head parameters toward the ratio calculation.
|
||||
num_params_for_ratio = _count_unique_params(model)
|
||||
if tie_embeddings:
|
||||
num_params_for_ratio += vocab_size * model_dim
|
||||
|
||||
user_config.update({
|
||||
"moe_num_experts": moe_num_experts,
|
||||
"moe_num_shared_experts": moe_num_shared_experts,
|
||||
|
|
@ -246,18 +272,27 @@ elif target_flops > 0:
|
|||
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
|
||||
elif target_param_data_ratio > 0:
|
||||
# calculate the number of iterations from the target param data ratio
|
||||
target_tokens = target_param_data_ratio * num_params
|
||||
target_tokens = target_param_data_ratio * num_params_for_ratio
|
||||
num_iterations = target_tokens // total_batch_size
|
||||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = total_batch_size * num_iterations
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params_for_ratio:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
if eval_every <= 0:
|
||||
eval_every = max(1, num_iterations // 100)
|
||||
print0(f"Auto-setting eval_every to {eval_every} (~1% of training)")
|
||||
def _resolve_progress_interval(name, configured_value, total_steps, default_frac):
    """Resolve how many steps should elapse between occurrences of a periodic action.

    A positive configured_value is treated as an explicit setting and returned
    as an int, clamped to at least 1. Any other value requests automatic
    selection: the interval becomes default_frac of total_steps, rounded up
    and clamped to at least 1, and the choice is reported through print0
    using `name` as the label.
    """
    if not configured_value > 0:
        auto_interval = max(1, math.ceil(total_steps * default_frac))
        print0(f"Auto-setting {name} to {auto_interval} (~{default_frac * 100:.2f}% of training)")
        return auto_interval
    # Explicit setting: truncate to int, but never let the interval drop to 0.
    return max(1, int(configured_value))
|
||||
|
||||
eval_every = _resolve_progress_interval("eval_every", eval_every, num_iterations, 0.01)
|
||||
user_config["eval_every"] = eval_every
|
||||
|
||||
log_every_steps = _resolve_progress_interval("log_every", log_every, num_iterations, 0.01)
|
||||
user_config["log_every"] = log_every_steps
|
||||
|
||||
sequences_per_step = max(1, total_batch_size // max_seq_len)
|
||||
checkpoint_every_steps = int(checkpoint_every_steps)
|
||||
|
|
@ -477,7 +512,7 @@ for step in range(num_iterations + 1):
|
|||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 100 == 0:
|
||||
if step % log_every_steps == 0:
|
||||
log_payload = {
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user