mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-20 20:03:19 +00:00
Add kv_head_mult parameter for training and evaluation scripts
- Introduced `kv_head_mult` to control the number of query heads sharing a key/value head in `base_train.py`, `mid_train.py`, and `runmps.sh`. - Updated logging to include global token per second metrics during training. - Added assertions to ensure `kv_head_mult` is valid and properly integrated into model calculations.
This commit is contained in:
parent
b1d49aade5
commit
8a6d34daf7
|
|
@ -129,6 +129,7 @@ BASE_DEPTH=${BASE_DEPTH:-4}
|
|||
SEQ_LEN=${SEQ_LEN:-1024}
|
||||
DEVICE_BATCH=${DEVICE_BATCH:-16}
|
||||
TOTAL_BATCH=${TOTAL_BATCH:-$((DEVICE_BATCH * SEQ_LEN))} # tokens per optimizer step
|
||||
KV_HEAD_MULT=${KV_HEAD_MULT:-1}
|
||||
EVAL_SEQUENCES=10000
|
||||
EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH))
|
||||
EVAL_BATCH_MULT=4 # evaluate on 4 full batches
|
||||
|
|
@ -243,6 +244,7 @@ python -m scripts.tok_eval
|
|||
--max_seq_len=$SEQ_LEN \
|
||||
--device_batch_size=$DEVICE_BATCH \
|
||||
--total_batch_size=$TOTAL_BATCH \
|
||||
--kv_head_mult=$KV_HEAD_MULT \
|
||||
--target_param_data_ratio=$TARGET_PARAM_DATA_RATIO \
|
||||
--run="$WANDB_RUN" \
|
||||
--eval_every=$EVAL_STEPS \
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, i
|
|||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
kv_head_mult = 1 # number of query heads that share a single key/value head (1 disables GQA)
|
||||
# Training horizon. Only one of these 3 will be used, in this order of precedence.
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
|
|
@ -101,9 +102,12 @@ print0(f"Vocab size: {vocab_size:,}")
|
|||
num_layers = depth
|
||||
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
|
||||
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
assert kv_head_mult >= 1, "kv_head_mult must be >= 1"
|
||||
assert num_heads % kv_head_mult == 0, f"num_heads ({num_heads}) must be divisible by kv_head_mult ({kv_head_mult})"
|
||||
num_kv_heads = max(1, num_heads // kv_head_mult)
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
print0(f"kv_head_mult: {kv_head_mult}")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
print0(f"num_kv_heads: {num_kv_heads}")
|
||||
|
||||
|
|
@ -338,12 +342,13 @@ for step in range(num_iterations + 1):
|
|||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
|
||||
global_tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 100 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
|
|
@ -353,6 +358,7 @@ for step in range(num_iterations + 1):
|
|||
"train/lrm": lrm,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/tok_per_sec_global": global_tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/total_tokens": total_tokens_seen,
|
||||
"train/total_sequences": total_sequences_seen,
|
||||
|
|
|
|||
|
|
@ -308,12 +308,13 @@ while True:
|
|||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * progress
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
|
||||
global_tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 10 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
|
|
@ -323,6 +324,7 @@ while True:
|
|||
"train/lrm": lrm,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/tok_per_sec_global": global_tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/total_tokens": total_tokens_seen,
|
||||
"train/total_sequences": total_sequences_seen,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user