From 9550053cc1d74522ecc12e1a631b060dcb938511 Mon Sep 17 00:00:00 2001 From: William Thurston Date: Mon, 10 Nov 2025 19:45:02 -0800 Subject: [PATCH] Enhance model tagging support in training and evaluation scripts - Added model tagging functionality to `runmps.sh`, allowing for dynamic model tagging based on the W&B run name. - Updated `base_train.py`, `mid_train.py`, and `chat_sft.py` to utilize model tags for checkpoint management. - Enhanced `base_eval.py` to accept model tags for loading models during evaluation. - Improved handling of model tags to ensure proper checkpoint directory naming and logging. --- dev/runmps.sh | 40 ++++++++++++++++++++++++++++++++++------ scripts/base_eval.py | 3 ++- scripts/base_train.py | 15 +++++++++++---- scripts/chat_sft.py | 14 ++++++++++++-- scripts/mid_train.py | 13 +++++++++++-- 5 files changed, 70 insertions(+), 15 deletions(-) diff --git a/dev/runmps.sh b/dev/runmps.sh index e73ea8c..5a3529d 100755 --- a/dev/runmps.sh +++ b/dev/runmps.sh @@ -238,6 +238,13 @@ python -m scripts.tok_eval export WANDB_EVAL_RUN=$WANDB_RUN export WANDB_PROJECT fi + if [ "$WANDB_RUN" != "dummy" ]; then + BASE_MODEL_TAG_FLAG="--model_tag=$WANDB_RUN" + BASE_MODEL_TAG_FLAG_HYPHEN="--model-tag=$WANDB_RUN" + else + BASE_MODEL_TAG_FLAG="" + BASE_MODEL_TAG_FLAG_HYPHEN="" + fi python -m scripts.base_train \ --depth=$BASE_DEPTH \ @@ -251,7 +258,8 @@ python -m scripts.tok_eval --eval_tokens=$EVAL_TOKENS \ --core_metric_every=-1 \ --sample_every=-1 \ - --checkpoint_every_steps=$BASE_CHECKPOINT_STEPS + --checkpoint_every_steps=$BASE_CHECKPOINT_STEPS \ + $BASE_MODEL_TAG_FLAG if [ "$WANDB_RUN" != "dummy" ]; then unset WANDB_RUN_ID @@ -259,8 +267,8 @@ python -m scripts.tok_eval fi if [ "$RUN_STAGE_EVALS" = "1" ]; then - python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS - python -m scripts.base_eval --max-per-task=16 + python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS $BASE_MODEL_TAG_FLAG + python -m scripts.base_eval --max-per-task=16 $BASE_MODEL_TAG_FLAG_HYPHEN fi fi @@ -274,6 +282,15 @@ if (( RUN_MID )); then export WANDB_EVAL_RUN="${WANDB_RUN}-mid" export WANDB_PROJECT fi + if [ "$WANDB_RUN" != "dummy" ]; then + MID_MODEL_TAG_FLAG="--model_tag=$WANDB_RUN" + MID_OUTPUT_MODEL_TAG_FLAG="--output_model_tag=${WANDB_RUN}-mid" + MID_EVAL_MODEL_TAG_FLAG="--model_tag=${WANDB_RUN}-mid" + else + MID_MODEL_TAG_FLAG="" + MID_OUTPUT_MODEL_TAG_FLAG="" + MID_EVAL_MODEL_TAG_FLAG="" + fi python -m scripts.mid_train \ --max_seq_len=$SEQ_LEN \ @@ -283,7 +300,9 @@ if (( RUN_MID )); then --eval_every=$EVAL_STEPS \ --eval_tokens=$EVAL_TOKENS \ --checkpoint_every_steps=$MID_CHECKPOINT_STEPS \ - --num_iterations=$MID_NUM_STEPS + --num_iterations=$MID_NUM_STEPS \ + $MID_MODEL_TAG_FLAG \ + $MID_OUTPUT_MODEL_TAG_FLAG if [ "$WANDB_RUN" != "dummy" ]; then unset WANDB_RUN_ID unset WANDB_EVAL_RUN @@ -291,7 +310,7 @@ if (( RUN_MID )); then if [ "$RUN_STAGE_EVALS" = "1" ]; then # eval results will be terrible, this is just to execute the code paths. # note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems - python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20 + python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20 $MID_EVAL_MODEL_TAG_FLAG fi fi @@ -305,6 +324,13 @@ if (( RUN_SFT )); then export WANDB_EVAL_RUN="${WANDB_RUN}-sft" export WANDB_PROJECT fi + if [ "$WANDB_RUN" != "dummy" ]; then + SFT_MODEL_TAG_FLAG="--model_tag=${WANDB_RUN}-mid" + SFT_OUTPUT_MODEL_TAG_FLAG="--output_model_tag=${WANDB_RUN}-sft" + else + SFT_MODEL_TAG_FLAG="" + SFT_OUTPUT_MODEL_TAG_FLAG="" + fi python -m scripts.chat_sft \ --device_batch_size=$SFT_DEVICE_BATCH \ @@ -315,7 +341,9 @@ if (( RUN_SFT )); then --eval_steps=$SFT_EVAL_STEPS \ --eval_metrics_every=$SFT_EVAL_METRICS_EVERY \ --eval_metrics_max_problems=$SFT_EVAL_METRICS_MAX \ - --checkpoint_every_steps=$SFT_CHECKPOINT_STEPS + --checkpoint_every_steps=$SFT_CHECKPOINT_STEPS \ + $SFT_MODEL_TAG_FLAG \ + $SFT_OUTPUT_MODEL_TAG_FLAG if [ "$WANDB_RUN" != "dummy" ]; then unset WANDB_RUN_ID diff --git a/scripts/base_eval.py b/scripts/base_eval.py index f6f002e..cb79140 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -125,6 +125,7 @@ def main(): parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate') parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)') parser.add_argument('--model-step', type=int, default=None, help='Checkpoint step to evaluate when using local models') + parser.add_argument('--model-tag', type=str, default=None, help='Model tag to load from base_checkpoints when evaluating local models') args = parser.parse_args() # distributed / precision setup @@ -144,7 +145,7 @@ def main(): model_slug = hf_path.replace("/", "-") # for the output csv file else: # load a local model from the file system - model, tokenizer, meta = load_model("base", device, phase="eval", step=args.model_step) + model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step) model_name = f"base_model (step {meta['step']})" # just for logging model_slug = f"base_model_{meta['step']:06d}" # for the output csv file if use_wandb: diff --git a/scripts/base_train.py b/scripts/base_train.py index 1d9cd4e..914df17 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -105,11 +105,21 @@ num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here i assert kv_head_mult >= 1, "kv_head_mult must be >= 1" assert num_heads % kv_head_mult == 0, f"num_heads ({num_heads}) must be divisible by kv_head_mult ({kv_head_mult})" num_kv_heads = max(1, num_heads // kv_head_mult) +def _resolve_checkpoint_tag(tag, run_name, depth_value): + if tag: + return tag + run_name = run_name or "" + if run_name and run_name != "dummy": + return run_name + return f"d{depth_value}" +model_tag = _resolve_checkpoint_tag(model_tag, run, depth) +user_config["model_tag"] = model_tag print0(f"num_layers: {num_layers}") print0(f"model_dim: {model_dim}") print0(f"kv_head_mult: {kv_head_mult}") print0(f"num_heads: {num_heads}") print0(f"num_kv_heads: {num_kv_heads}") +print0(f"Checkpoint tag: {model_tag}") # Optimizer / data / training length related hyperparameters # figure out the needed gradient accumulation to reach the desired total batch size @@ -172,8 +182,7 @@ build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, x, y = next(train_loader) # kick off load of the very first batch of data # Checkpoint output location -checkpoint_dirname = model_tag if model_tag else f"d{depth}" -checkpoint_dir = os.path.join(base_dir, "base_checkpoints", checkpoint_dirname) +checkpoint_dir = os.path.join(base_dir, "base_checkpoints", model_tag) # ----------------------------------------------------------------------------- # Set up hyperparameter schedulers @@ -209,8 +218,6 @@ total_sequences_seen = 0 last_val_bpb = None def save_base_checkpoint(step_idx): - output_dirname = model_tag if model_tag else f"d{depth}" - checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname) optimizer_state = [opt.state_dict() for opt in optimizers] meta = { "step": step_idx, diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 25ba0fa..4550f15 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -37,6 +37,7 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wan source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model) model_tag = None # model tag to load the model from (base model or midtrained model) step = None # step to load the model from (base model or midtrained model) +output_model_tag = "" # optional override for the checkpoint directory where SFT checkpoints are saved # compute/precision device_type = "" # cuda|cpu|mps (empty => autodetect) dtype = "bfloat16" @@ -102,6 +103,13 @@ orig_model = model # original, uncompiled model # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs engine = Engine(model, tokenizer) # will be used for inline model evaluation only model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer +def _resolve_checkpoint_tag(tag, run_name, fallback_tag): + if tag: + return tag + run_name = run_name or "" + if run_name and run_name != "dummy": + return run_name + return fallback_tag # ----------------------------------------------------------------------------- # Task data mixture we'll train on @@ -172,8 +180,10 @@ checkpoint_every_steps = int(checkpoint_every_steps) checkpoint_enabled = checkpoint_every_steps > 0 base_dir = get_base_dir() -checkpoint_dirname = model_tag if model_tag else f"d{model.config.n_layer}" -checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", checkpoint_dirname) +fallback_checkpoint_tag = model_tag if model_tag else f"d{model.config.n_layer}" +checkpoint_tag = _resolve_checkpoint_tag(output_model_tag, run, fallback_checkpoint_tag) +print0(f"Checkpoint tag: {checkpoint_tag}") +checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", checkpoint_tag) # ----------------------------------------------------------------------------- # Initialize the Optimizer diff --git a/scripts/mid_train.py b/scripts/mid_train.py index c1a9d0f..c1f3902 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -35,6 +35,7 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wan device_type = "" # cuda|cpu|mps (empty => autodetect) model_tag = None # model tag to load the model from (base model or midtrained model) step = None # step to load the model from (base model or midtrained model) +output_model_tag = "" # optional override for the checkpoint directory name we save midtraining snapshots to dtype = "bfloat16" num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) max_seq_len = 2048 @@ -86,6 +87,14 @@ if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size: orig_model = model model = torch.compile(model, dynamic=False) depth = model.config.n_layer +def _resolve_checkpoint_tag(tag, run_name, depth_value): + if tag: + return tag + run_name = run_name or "" + if run_name and run_name != "dummy": + return run_name + return f"d{depth_value}" +checkpoint_tag = _resolve_checkpoint_tag(output_model_tag, run, depth) num_flops_per_token = model.estimate_flops() tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks @@ -110,6 +119,7 @@ for opt in optimizers: group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later # Midtraining data mixture and DataLoader +print0(f"Checkpoint tag: {checkpoint_tag}") base_dir = get_base_dir() identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") train_dataset = TaskMixture([ @@ -176,8 +186,7 @@ train_loader = mid_data_generator("train") build_val_loader = lambda: mid_data_generator("val") progress = 0 # will go from 0 to 1 over the course of the epoch -checkpoint_dirname = f"d{depth}" -checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", checkpoint_dirname) +checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", checkpoint_tag) # Learning rate scheduler def get_lr_multiplier(progress):