mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-26 06:35:15 +00:00
Enhance model tagging support in training and evaluation scripts
- Added model tagging functionality to `runmps.sh`, allowing for dynamic model tagging based on the W&B run name.
- Updated `base_train.py`, `mid_train.py`, and `chat_sft.py` to utilize model tags for checkpoint management.
- Enhanced `base_eval.py` to accept model tags for loading models during evaluation.
- Improved handling of model tags to ensure proper checkpoint directory naming and logging.
This commit is contained in:
parent
8a6d34daf7
commit
9550053cc1
|
|
@ -238,6 +238,13 @@ python -m scripts.tok_eval
|
|||
export WANDB_EVAL_RUN=$WANDB_RUN
|
||||
export WANDB_PROJECT
|
||||
fi
|
||||
if [ "$WANDB_RUN" != "dummy" ]; then
|
||||
BASE_MODEL_TAG_FLAG="--model_tag=$WANDB_RUN"
|
||||
BASE_MODEL_TAG_FLAG_HYPHEN="--model-tag=$WANDB_RUN"
|
||||
else
|
||||
BASE_MODEL_TAG_FLAG=""
|
||||
BASE_MODEL_TAG_FLAG_HYPHEN=""
|
||||
fi
|
||||
|
||||
python -m scripts.base_train \
|
||||
--depth=$BASE_DEPTH \
|
||||
|
|
@ -251,7 +258,8 @@ python -m scripts.tok_eval
|
|||
--eval_tokens=$EVAL_TOKENS \
|
||||
--core_metric_every=-1 \
|
||||
--sample_every=-1 \
|
||||
--checkpoint_every_steps=$BASE_CHECKPOINT_STEPS
|
||||
--checkpoint_every_steps=$BASE_CHECKPOINT_STEPS \
|
||||
$BASE_MODEL_TAG_FLAG
|
||||
|
||||
if [ "$WANDB_RUN" != "dummy" ]; then
|
||||
unset WANDB_RUN_ID
|
||||
|
|
@ -259,8 +267,8 @@ python -m scripts.tok_eval
|
|||
fi
|
||||
|
||||
if [ "$RUN_STAGE_EVALS" = "1" ]; then
|
||||
python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS
|
||||
python -m scripts.base_eval --max-per-task=16
|
||||
python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS $BASE_MODEL_TAG_FLAG
|
||||
python -m scripts.base_eval --max-per-task=16 $BASE_MODEL_TAG_FLAG_HYPHEN
|
||||
fi
|
||||
fi
|
||||
|
||||
|
|
@ -274,6 +282,15 @@ if (( RUN_MID )); then
|
|||
export WANDB_EVAL_RUN="${WANDB_RUN}-mid"
|
||||
export WANDB_PROJECT
|
||||
fi
|
||||
if [ "$WANDB_RUN" != "dummy" ]; then
|
||||
MID_MODEL_TAG_FLAG="--model_tag=$WANDB_RUN"
|
||||
MID_OUTPUT_MODEL_TAG_FLAG="--output_model_tag=${WANDB_RUN}-mid"
|
||||
MID_EVAL_MODEL_TAG_FLAG="--model_tag=${WANDB_RUN}-mid"
|
||||
else
|
||||
MID_MODEL_TAG_FLAG=""
|
||||
MID_OUTPUT_MODEL_TAG_FLAG=""
|
||||
MID_EVAL_MODEL_TAG_FLAG=""
|
||||
fi
|
||||
|
||||
python -m scripts.mid_train \
|
||||
--max_seq_len=$SEQ_LEN \
|
||||
|
|
@ -283,7 +300,9 @@ if (( RUN_MID )); then
|
|||
--eval_every=$EVAL_STEPS \
|
||||
--eval_tokens=$EVAL_TOKENS \
|
||||
--checkpoint_every_steps=$MID_CHECKPOINT_STEPS \
|
||||
--num_iterations=$MID_NUM_STEPS
|
||||
--num_iterations=$MID_NUM_STEPS \
|
||||
$MID_MODEL_TAG_FLAG \
|
||||
$MID_OUTPUT_MODEL_TAG_FLAG
|
||||
if [ "$WANDB_RUN" != "dummy" ]; then
|
||||
unset WANDB_RUN_ID
|
||||
unset WANDB_EVAL_RUN
|
||||
|
|
@ -291,7 +310,7 @@ if (( RUN_MID )); then
|
|||
if [ "$RUN_STAGE_EVALS" = "1" ]; then
|
||||
# eval results will be terrible, this is just to execute the code paths.
|
||||
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
|
||||
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
|
||||
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20 $MID_EVAL_MODEL_TAG_FLAG
|
||||
fi
|
||||
fi
|
||||
|
||||
|
|
@ -305,6 +324,13 @@ if (( RUN_SFT )); then
|
|||
export WANDB_EVAL_RUN="${WANDB_RUN}-sft"
|
||||
export WANDB_PROJECT
|
||||
fi
|
||||
if [ "$WANDB_RUN" != "dummy" ]; then
|
||||
SFT_MODEL_TAG_FLAG="--model_tag=${WANDB_RUN}-mid"
|
||||
SFT_OUTPUT_MODEL_TAG_FLAG="--output_model_tag=${WANDB_RUN}-sft"
|
||||
else
|
||||
SFT_MODEL_TAG_FLAG=""
|
||||
SFT_OUTPUT_MODEL_TAG_FLAG=""
|
||||
fi
|
||||
|
||||
python -m scripts.chat_sft \
|
||||
--device_batch_size=$SFT_DEVICE_BATCH \
|
||||
|
|
@ -315,7 +341,9 @@ if (( RUN_SFT )); then
|
|||
--eval_steps=$SFT_EVAL_STEPS \
|
||||
--eval_metrics_every=$SFT_EVAL_METRICS_EVERY \
|
||||
--eval_metrics_max_problems=$SFT_EVAL_METRICS_MAX \
|
||||
--checkpoint_every_steps=$SFT_CHECKPOINT_STEPS
|
||||
--checkpoint_every_steps=$SFT_CHECKPOINT_STEPS \
|
||||
$SFT_MODEL_TAG_FLAG \
|
||||
$SFT_OUTPUT_MODEL_TAG_FLAG
|
||||
|
||||
if [ "$WANDB_RUN" != "dummy" ]; then
|
||||
unset WANDB_RUN_ID
|
||||
|
|
|
|||
|
|
@ -125,6 +125,7 @@ def main():
|
|||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
|
||||
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
|
||||
parser.add_argument('--model-step', type=int, default=None, help='Checkpoint step to evaluate when using local models')
|
||||
parser.add_argument('--model-tag', type=str, default=None, help='Model tag to load from base_checkpoints when evaluating local models')
|
||||
args = parser.parse_args()
|
||||
|
||||
# distributed / precision setup
|
||||
|
|
@ -144,7 +145,7 @@ def main():
|
|||
model_slug = hf_path.replace("/", "-") # for the output csv file
|
||||
else:
|
||||
# load a local model from the file system
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", step=args.model_step)
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
model_name = f"base_model (step {meta['step']})" # just for logging
|
||||
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
|
||||
if use_wandb:
|
||||
|
|
|
|||
|
|
@ -105,11 +105,21 @@ num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here i
|
|||
assert kv_head_mult >= 1, "kv_head_mult must be >= 1"
|
||||
assert num_heads % kv_head_mult == 0, f"num_heads ({num_heads}) must be divisible by kv_head_mult ({kv_head_mult})"
|
||||
num_kv_heads = max(1, num_heads // kv_head_mult)
|
||||
def _resolve_checkpoint_tag(tag, run_name, depth_value):
|
||||
if tag:
|
||||
return tag
|
||||
run_name = run_name or ""
|
||||
if run_name and run_name != "dummy":
|
||||
return run_name
|
||||
return f"d{depth_value}"
|
||||
model_tag = _resolve_checkpoint_tag(model_tag, run, depth)
|
||||
user_config["model_tag"] = model_tag
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
print0(f"kv_head_mult: {kv_head_mult}")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
print0(f"num_kv_heads: {num_kv_heads}")
|
||||
print0(f"Checkpoint tag: {model_tag}")
|
||||
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
|
|
@ -172,8 +182,7 @@ build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size,
|
|||
x, y = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# Checkpoint output location
|
||||
checkpoint_dirname = model_tag if model_tag else f"d{depth}"
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", checkpoint_dirname)
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", model_tag)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
|
@ -209,8 +218,6 @@ total_sequences_seen = 0
|
|||
last_val_bpb = None
|
||||
|
||||
def save_base_checkpoint(step_idx):
|
||||
output_dirname = model_tag if model_tag else f"d{depth}"
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
optimizer_state = [opt.state_dict() for opt in optimizers]
|
||||
meta = {
|
||||
"step": step_idx,
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wan
|
|||
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
output_model_tag = "" # optional override for the checkpoint directory where SFT checkpoints are saved
|
||||
# compute/precision
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
dtype = "bfloat16"
|
||||
|
|
@ -102,6 +103,13 @@ orig_model = model # original, uncompiled model
|
|||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
def _resolve_checkpoint_tag(tag, run_name, fallback_tag):
|
||||
if tag:
|
||||
return tag
|
||||
run_name = run_name or ""
|
||||
if run_name and run_name != "dummy":
|
||||
return run_name
|
||||
return fallback_tag
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Task data mixture we'll train on
|
||||
|
|
@ -172,8 +180,10 @@ checkpoint_every_steps = int(checkpoint_every_steps)
|
|||
checkpoint_enabled = checkpoint_every_steps > 0
|
||||
|
||||
base_dir = get_base_dir()
|
||||
checkpoint_dirname = model_tag if model_tag else f"d{model.config.n_layer}"
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", checkpoint_dirname)
|
||||
fallback_checkpoint_tag = model_tag if model_tag else f"d{model.config.n_layer}"
|
||||
checkpoint_tag = _resolve_checkpoint_tag(output_model_tag, run, fallback_checkpoint_tag)
|
||||
print0(f"Checkpoint tag: {checkpoint_tag}")
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", checkpoint_tag)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wan
|
|||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
output_model_tag = "" # optional override for the checkpoint directory name we save midtraining snapshots to
|
||||
dtype = "bfloat16"
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
max_seq_len = 2048
|
||||
|
|
@ -86,6 +87,14 @@ if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
|
|||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
def _resolve_checkpoint_tag(tag, run_name, depth_value):
|
||||
if tag:
|
||||
return tag
|
||||
run_name = run_name or ""
|
||||
if run_name and run_name != "dummy":
|
||||
return run_name
|
||||
return f"d{depth_value}"
|
||||
checkpoint_tag = _resolve_checkpoint_tag(output_model_tag, run, depth)
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
|
|
@ -110,6 +119,7 @@ for opt in optimizers:
|
|||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
print0(f"Checkpoint tag: {checkpoint_tag}")
|
||||
base_dir = get_base_dir()
|
||||
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
||||
train_dataset = TaskMixture([
|
||||
|
|
@ -176,8 +186,7 @@ train_loader = mid_data_generator("train")
|
|||
build_val_loader = lambda: mid_data_generator("val")
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
|
||||
checkpoint_dirname = f"d{depth}"
|
||||
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", checkpoint_dirname)
|
||||
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", checkpoint_tag)
|
||||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(progress):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user