diff --git a/dev/runmps.sh b/dev/runmps.sh
index da8e216..e73ea8c 100755
--- a/dev/runmps.sh
+++ b/dev/runmps.sh
@@ -129,6 +129,7 @@ BASE_DEPTH=${BASE_DEPTH:-4}
 SEQ_LEN=${SEQ_LEN:-1024}
 DEVICE_BATCH=${DEVICE_BATCH:-16}
 TOTAL_BATCH=${TOTAL_BATCH:-$((DEVICE_BATCH * SEQ_LEN))} # tokens per optimizer step
+KV_HEAD_MULT=${KV_HEAD_MULT:-1}
 EVAL_SEQUENCES=10000
 EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH))
 EVAL_BATCH_MULT=4 # evaluate on 4 full batches
@@ -243,6 +244,7 @@ python -m scripts.tok_eval
     --max_seq_len=$SEQ_LEN \
     --device_batch_size=$DEVICE_BATCH \
     --total_batch_size=$TOTAL_BATCH \
+    --kv_head_mult=$KV_HEAD_MULT \
     --target_param_data_ratio=$TARGET_PARAM_DATA_RATIO \
     --run="$WANDB_RUN" \
     --eval_every=$EVAL_STEPS \
diff --git a/scripts/base_train.py b/scripts/base_train.py
index ee394ca..1d9cd4e 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -37,6 +37,7 @@ device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, i
 # Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
 max_seq_len = 2048 # max context length
+kv_head_mult = 1 # number of query heads that share a single key/value head (1 disables GQA)
 # Training horizon. Only one of these 3 will be used, in this order of precedence.
 num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
 target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
@@ -101,9 +102,12 @@ print0(f"Vocab size: {vocab_size:,}")
 num_layers = depth
 model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
 num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
-num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
+assert kv_head_mult >= 1, "kv_head_mult must be >= 1"
+assert num_heads % kv_head_mult == 0, f"num_heads ({num_heads}) must be divisible by kv_head_mult ({kv_head_mult})"
+num_kv_heads = max(1, num_heads // kv_head_mult)
 print0(f"num_layers: {num_layers}")
 print0(f"model_dim: {model_dim}")
+print0(f"kv_head_mult: {kv_head_mult}")
 print0(f"num_heads: {num_heads}")
 print0(f"num_kv_heads: {num_kv_heads}")
 
@@ -338,12 +342,13 @@ for step in range(num_iterations + 1):
     debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
     pct_done = 100 * step / num_iterations
     tok_per_sec = int(world_tokens_per_fwdbwd / dt)
+    global_tok_per_sec = int(total_batch_size / dt)
     flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
-    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 100 == 0:
         wandb_run.log({
             "step": step,
@@ -353,6 +358,7 @@ for step in range(num_iterations + 1):
             "train/lrm": lrm,
             "train/dt": dt,
             "train/tok_per_sec": tok_per_sec,
+            "train/tok_per_sec_global": global_tok_per_sec,
             "train/mfu": mfu,
             "train/total_tokens": total_tokens_seen,
             "train/total_sequences": total_sequences_seen,
diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index 858d9c4..c1a9d0f 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -308,12 +308,13 @@ while True:
     debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
     pct_done = 100 * progress
     tok_per_sec = int(world_tokens_per_fwdbwd / dt)
+    global_tok_per_sec = int(total_batch_size / dt)
     flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
-    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 10 == 0:
         wandb_run.log({
             "step": step,
@@ -323,6 +324,7 @@ while True:
             "train/lrm": lrm,
             "train/dt": dt,
             "train/tok_per_sec": tok_per_sec,
+            "train/tok_per_sec_global": global_tok_per_sec,
             "train/mfu": mfu,
             "train/total_tokens": total_tokens_seen,
             "train/total_sequences": total_sequences_seen,
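
Note on the GQA change: below is a minimal standalone sketch (not part of the diff) that mirrors how `scripts/base_train.py` derives the head counts from `depth` and the new `kv_head_mult` knob; the helper name `derive_heads` is hypothetical and exists only so the valid `kv_head_mult` values can be sanity-checked without launching a run.

```python
# Illustrative sketch only; mirrors the derivation added in scripts/base_train.py.
def derive_heads(depth: int, kv_head_mult: int = 1) -> tuple[int, int, int]:
    model_dim = depth * 64                        # aspect ratio 64, as in base_train.py
    num_heads = max(1, (model_dim + 127) // 128)  # head dim 128 (ceil division)
    assert kv_head_mult >= 1, "kv_head_mult must be >= 1"
    assert num_heads % kv_head_mult == 0, \
        f"num_heads ({num_heads}) must be divisible by kv_head_mult ({kv_head_mult})"
    num_kv_heads = max(1, num_heads // kv_head_mult)  # kv_head_mult query heads share one KV head
    return model_dim, num_heads, num_kv_heads

if __name__ == "__main__":
    # depth=20 (the default) gives model_dim=1280 and num_heads=10,
    # so kv_head_mult must be one of 1, 2, 5, 10.
    for mult in (1, 2, 5, 10):
        print(mult, derive_heads(20, mult))
```

On the shell side, the plumbing in dev/runmps.sh means something like `KV_HEAD_MULT=2 dev/runmps.sh` should forward `--kv_head_mult=2` to the training script; the default of 1 keeps GQA disabled as before.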
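
Note on the logging change: a small sketch (not part of the diff, illustrative numbers only) of what the two throughput figures in the new log line measure, assuming `dt` times one full optimizer step, as the existing MFU computation (`total_batch_size / dt`) implies.

```python
# Illustrative only: example values, not measured numbers.
device_batch_size, max_seq_len, world_size, grad_accum = 16, 1024, 8, 4
world_tokens_per_fwdbwd = device_batch_size * max_seq_len * world_size  # tokens in one fwd/bwd across ranks
total_batch_size = world_tokens_per_fwdbwd * grad_accum                 # tokens per optimizer step
dt = 0.5                                                                # seconds per optimizer step (assumed)

tok_per_sec = int(world_tokens_per_fwdbwd / dt)     # "micro": a single micro-batch over the full step time
global_tok_per_sec = int(total_batch_size / dt)     # "global": all tokens consumed per optimizer step
print(f"tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,}")
```

The two numbers coincide when there is no gradient accumulation and diverge otherwise, which is why the log line and the wandb metrics now report both.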