Enhance model tagging support in training and evaluation scripts

- Added model tagging functionality to `runmps.sh`, allowing for dynamic model tagging based on the W&B run name.
- Updated `base_train.py`, `mid_train.py`, and `chat_sft.py` to utilize model tags for checkpoint management.
- Enhanced `base_eval.py` to accept model tags for loading models during evaluation.
- Improved handling of model tags to ensure proper checkpoint directory naming and logging.
This commit is contained in:
William Thurston 2025-11-10 19:45:02 -08:00
parent 8a6d34daf7
commit 9550053cc1
5 changed files with 70 additions and 15 deletions

View File

@ -238,6 +238,13 @@ python -m scripts.tok_eval
export WANDB_EVAL_RUN=$WANDB_RUN
export WANDB_PROJECT
fi
# Default to no tag flags; only populate them for a real (non-dummy) W&B run
# so the training/eval scripts fall back to their own checkpoint naming.
BASE_MODEL_TAG_FLAG=""
BASE_MODEL_TAG_FLAG_HYPHEN=""
if [ "$WANDB_RUN" != "dummy" ]; then
  BASE_MODEL_TAG_FLAG="--model_tag=$WANDB_RUN"
  BASE_MODEL_TAG_FLAG_HYPHEN="--model-tag=$WANDB_RUN"
fi
python -m scripts.base_train \
--depth=$BASE_DEPTH \
@ -251,7 +258,8 @@ python -m scripts.tok_eval
--eval_tokens=$EVAL_TOKENS \
--core_metric_every=-1 \
--sample_every=-1 \
--checkpoint_every_steps=$BASE_CHECKPOINT_STEPS
--checkpoint_every_steps=$BASE_CHECKPOINT_STEPS \
$BASE_MODEL_TAG_FLAG
if [ "$WANDB_RUN" != "dummy" ]; then
unset WANDB_RUN_ID
@ -259,8 +267,8 @@ python -m scripts.tok_eval
fi
if [ "$RUN_STAGE_EVALS" = "1" ]; then
python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS
python -m scripts.base_eval --max-per-task=16
python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS $BASE_MODEL_TAG_FLAG
python -m scripts.base_eval --max-per-task=16 $BASE_MODEL_TAG_FLAG_HYPHEN
fi
fi
@ -274,6 +282,15 @@ if (( RUN_MID )); then
export WANDB_EVAL_RUN="${WANDB_RUN}-mid"
export WANDB_PROJECT
fi
# Default to empty flags (scripts use their own fallback naming); for a real
# run, load the base checkpoint tagged by the run name and save midtraining
# checkpoints under "<run>-mid".
MID_MODEL_TAG_FLAG=""
MID_OUTPUT_MODEL_TAG_FLAG=""
MID_EVAL_MODEL_TAG_FLAG=""
if [ "$WANDB_RUN" != "dummy" ]; then
  MID_MODEL_TAG_FLAG="--model_tag=$WANDB_RUN"
  MID_OUTPUT_MODEL_TAG_FLAG="--output_model_tag=${WANDB_RUN}-mid"
  MID_EVAL_MODEL_TAG_FLAG="--model_tag=${WANDB_RUN}-mid"
fi
python -m scripts.mid_train \
--max_seq_len=$SEQ_LEN \
@ -283,7 +300,9 @@ if (( RUN_MID )); then
--eval_every=$EVAL_STEPS \
--eval_tokens=$EVAL_TOKENS \
--checkpoint_every_steps=$MID_CHECKPOINT_STEPS \
--num_iterations=$MID_NUM_STEPS
--num_iterations=$MID_NUM_STEPS \
$MID_MODEL_TAG_FLAG \
$MID_OUTPUT_MODEL_TAG_FLAG
if [ "$WANDB_RUN" != "dummy" ]; then
unset WANDB_RUN_ID
unset WANDB_EVAL_RUN
@ -291,7 +310,7 @@ if (( RUN_MID )); then
if [ "$RUN_STAGE_EVALS" = "1" ]; then
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20 $MID_EVAL_MODEL_TAG_FLAG
fi
fi
@ -305,6 +324,13 @@ if (( RUN_SFT )); then
export WANDB_EVAL_RUN="${WANDB_RUN}-sft"
export WANDB_PROJECT
fi
# Default to empty flags (scripts use their own fallback naming); for a real
# run, SFT loads the "<run>-mid" checkpoint and saves under "<run>-sft".
SFT_MODEL_TAG_FLAG=""
SFT_OUTPUT_MODEL_TAG_FLAG=""
if [ "$WANDB_RUN" != "dummy" ]; then
  SFT_MODEL_TAG_FLAG="--model_tag=${WANDB_RUN}-mid"
  SFT_OUTPUT_MODEL_TAG_FLAG="--output_model_tag=${WANDB_RUN}-sft"
fi
python -m scripts.chat_sft \
--device_batch_size=$SFT_DEVICE_BATCH \
@ -315,7 +341,9 @@ if (( RUN_SFT )); then
--eval_steps=$SFT_EVAL_STEPS \
--eval_metrics_every=$SFT_EVAL_METRICS_EVERY \
--eval_metrics_max_problems=$SFT_EVAL_METRICS_MAX \
--checkpoint_every_steps=$SFT_CHECKPOINT_STEPS
--checkpoint_every_steps=$SFT_CHECKPOINT_STEPS \
$SFT_MODEL_TAG_FLAG \
$SFT_OUTPUT_MODEL_TAG_FLAG
if [ "$WANDB_RUN" != "dummy" ]; then
unset WANDB_RUN_ID

View File

@ -125,6 +125,7 @@ def main():
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
parser.add_argument('--model-step', type=int, default=None, help='Checkpoint step to evaluate when using local models')
parser.add_argument('--model-tag', type=str, default=None, help='Model tag to load from base_checkpoints when evaluating local models')
args = parser.parse_args()
# distributed / precision setup
@ -144,7 +145,7 @@ def main():
model_slug = hf_path.replace("/", "-") # for the output csv file
else:
# load a local model from the file system
model, tokenizer, meta = load_model("base", device, phase="eval", step=args.model_step)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
model_name = f"base_model (step {meta['step']})" # just for logging
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
if use_wandb:

View File

@ -105,11 +105,21 @@ num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here i
assert kv_head_mult >= 1, "kv_head_mult must be >= 1"
assert num_heads % kv_head_mult == 0, f"num_heads ({num_heads}) must be divisible by kv_head_mult ({kv_head_mult})"
num_kv_heads = max(1, num_heads // kv_head_mult)
def _resolve_checkpoint_tag(tag, run_name, depth_value):
if tag:
return tag
run_name = run_name or ""
if run_name and run_name != "dummy":
return run_name
return f"d{depth_value}"
model_tag = _resolve_checkpoint_tag(model_tag, run, depth)
user_config["model_tag"] = model_tag
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"kv_head_mult: {kv_head_mult}")
print0(f"num_heads: {num_heads}")
print0(f"num_kv_heads: {num_kv_heads}")
print0(f"Checkpoint tag: {model_tag}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
@ -172,8 +182,7 @@ build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size,
x, y = next(train_loader) # kick off load of the very first batch of data
# Checkpoint output location
checkpoint_dirname = model_tag if model_tag else f"d{depth}"
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", checkpoint_dirname)
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", model_tag)
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
@ -209,8 +218,6 @@ total_sequences_seen = 0
last_val_bpb = None
def save_base_checkpoint(step_idx):
output_dirname = model_tag if model_tag else f"d{depth}"
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
optimizer_state = [opt.state_dict() for opt in optimizers]
meta = {
"step": step_idx,

View File

@ -37,6 +37,7 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wan
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
output_model_tag = "" # optional override for the checkpoint directory where SFT checkpoints are saved
# compute/precision
device_type = "" # cuda|cpu|mps (empty => autodetect)
dtype = "bfloat16"
@ -102,6 +103,13 @@ orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
def _resolve_checkpoint_tag(tag, run_name, fallback_tag):
if tag:
return tag
run_name = run_name or ""
if run_name and run_name != "dummy":
return run_name
return fallback_tag
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
@ -172,8 +180,10 @@ checkpoint_every_steps = int(checkpoint_every_steps)
checkpoint_enabled = checkpoint_every_steps > 0
base_dir = get_base_dir()
checkpoint_dirname = model_tag if model_tag else f"d{model.config.n_layer}"
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", checkpoint_dirname)
fallback_checkpoint_tag = model_tag if model_tag else f"d{model.config.n_layer}"
checkpoint_tag = _resolve_checkpoint_tag(output_model_tag, run, fallback_checkpoint_tag)
print0(f"Checkpoint tag: {checkpoint_tag}")
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", checkpoint_tag)
# -----------------------------------------------------------------------------
# Initialize the Optimizer

View File

@ -35,6 +35,7 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wan
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
output_model_tag = "" # optional override for the checkpoint directory name we save midtraining snapshots to
dtype = "bfloat16"
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
max_seq_len = 2048
@ -86,6 +87,14 @@ if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
def _resolve_checkpoint_tag(tag, run_name, depth_value):
if tag:
return tag
run_name = run_name or ""
if run_name and run_name != "dummy":
return run_name
return f"d{depth_value}"
checkpoint_tag = _resolve_checkpoint_tag(output_model_tag, run, depth)
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
@ -110,6 +119,7 @@ for opt in optimizers:
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Midtraining data mixture and DataLoader
print0(f"Checkpoint tag: {checkpoint_tag}")
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
@ -176,8 +186,7 @@ train_loader = mid_data_generator("train")
build_val_loader = lambda: mid_data_generator("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
checkpoint_dirname = f"d{depth}"
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", checkpoint_dirname)
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", checkpoint_tag)
# Learning rate scheduler
def get_lr_multiplier(progress):