Merge pull request #10 from Dianababaei/refactor/chat-sft-use-orig-model-for-eval-and-checkpointing

Update chat supervised fine-tuning script with improved training configuration and enhanced pipeline functionality
This commit is contained in:
Dianababaei 2025-11-05 19:42:13 +03:30 committed by GitHub
commit 49d29417f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -69,8 +69,8 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
orig_model = model # original, uncompiled model
model = torch.compile(model, dynamic=False) # Fixed shapes enable efficient compilation
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(orig_model, tokenizer) # will be used for inline model evaluation only
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
@ -189,8 +189,8 @@ for step in range(num_iterations):
metrics = {}
with torch.no_grad(), autocast_ctx:
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["mmlu_acc"] = run_chat_eval("MMLU", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({
@ -248,7 +248,7 @@ if master_process:
save_checkpoint(
checkpoint_dir,
step,
model.state_dict(),
orig_model.state_dict(),
None, # note: we don't bother to save the optimizer state
{
"step": step,