Add kv_head_mult parameter for training and evaluation scripts

- Introduced `kv_head_mult` to control the number of query heads sharing a key/value head in `base_train.py`, `mid_train.py`, and `runmps.sh`.
- Updated logging to include global token per second metrics during training.
- Added assertions to ensure `kv_head_mult` is valid and properly integrated into model calculations.
This commit is contained in:
William Thurston 2025-11-09 14:23:45 -08:00
parent b1d49aade5
commit 8a6d34daf7
3 changed files with 13 additions and 3 deletions

View File

@@ -129,6 +129,7 @@ BASE_DEPTH=${BASE_DEPTH:-4}
SEQ_LEN=${SEQ_LEN:-1024}
DEVICE_BATCH=${DEVICE_BATCH:-16}
TOTAL_BATCH=${TOTAL_BATCH:-$((DEVICE_BATCH * SEQ_LEN))} # tokens per optimizer step
KV_HEAD_MULT=${KV_HEAD_MULT:-1}
EVAL_SEQUENCES=10000
EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH))
EVAL_BATCH_MULT=4 # evaluate on 4 full batches
@@ -243,6 +244,7 @@ python -m scripts.tok_eval
--max_seq_len=$SEQ_LEN \
--device_batch_size=$DEVICE_BATCH \
--total_batch_size=$TOTAL_BATCH \
--kv_head_mult=$KV_HEAD_MULT \
--target_param_data_ratio=$TARGET_PARAM_DATA_RATIO \
--run="$WANDB_RUN" \
--eval_every=$EVAL_STEPS \

View File

@@ -37,6 +37,7 @@ device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, i
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
kv_head_mult = 1 # number of query heads that share a single key/value head (1 disables GQA)
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
@@ -101,9 +102,12 @@ print0(f"Vocab size: {vocab_size:,}")
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
assert kv_head_mult >= 1, "kv_head_mult must be >= 1"
assert num_heads % kv_head_mult == 0, f"num_heads ({num_heads}) must be divisible by kv_head_mult ({kv_head_mult})"
num_kv_heads = max(1, num_heads // kv_head_mult)
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"kv_head_mult: {kv_head_mult}")
print0(f"num_heads: {num_heads}")
print0(f"num_kv_heads: {num_kv_heads}")
@@ -338,12 +342,13 @@ for step in range(num_iterations + 1):
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * step / num_iterations
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
global_tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 100 == 0:
wandb_run.log({
"step": step,
@@ -353,6 +358,7 @@ for step in range(num_iterations + 1):
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/tok_per_sec_global": global_tok_per_sec,
"train/mfu": mfu,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,

View File

@@ -308,12 +308,13 @@ while True:
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * progress
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
global_tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0:
wandb_run.log({
"step": step,
@@ -323,6 +324,7 @@ while True:
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/tok_per_sec_global": global_tok_per_sec,
"train/mfu": mfu,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,