Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-07 21:02:15 +00:00)

Merge pull request #10 from Dianababaei/refactor/chat-sft-use-orig-model-for-eval-and-checkpointing

Update the chat supervised fine-tuning script to use the original, uncompiled model (rather than the torch.compile wrapper) for inline evaluation and checkpointing.

Commit 49d29417f1
@@ -69,8 +69,8 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
 # Load the model and tokenizer
 model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
 orig_model = model # original, uncompiled model
-model = torch.compile(model, dynamic=False) # Fixed shapes enable efficient compilation
-engine = Engine(model, tokenizer) # will be used for inline model evaluation only
+# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
+engine = Engine(orig_model, tokenizer) # will be used for inline model evaluation only

 # -----------------------------------------------------------------------------
 # Task data mixture we'll train on
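
Why a handle to the uncompiled module can stand in for the compiled one: whether or not the compile line is enabled, torch.compile returns an OptimizedModule that wraps, rather than copies, the original nn.Module, so the two share parameters and gradient updates made through the wrapper are immediately visible on the original. A minimal standalone sketch (toy module, not from this repo):

import torch
import torch.nn as nn

orig = nn.Linear(4, 4)                    # stand-in for the nanochat model
compiled = torch.compile(orig)            # wraps `orig`, does not copy it
assert compiled._orig_mod is orig         # the wrapper holds the same module

opt = torch.optim.SGD(orig.parameters(), lr=0.1)
loss = compiled(torch.randn(2, 4)).sum()  # train through the compiled wrapper
loss.backward()
opt.step()                                # ...and the update lands in `orig` too
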
@@ -189,8 +189,8 @@ for step in range(num_iterations):
         metrics = {}
         with torch.no_grad(), autocast_ctx:
             # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
-            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
-            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
+            metrics["mmlu_acc"] = run_chat_eval("MMLU", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
+            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
         metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
         print0(f"Step {step:05d} | {metrics_str}")
         wandb_run.log({
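
Passing orig_model here matters whenever compilation is enabled: a module compiled with dynamic=False specializes on exact input shapes, and evaluation batches of varying sequence length would trigger repeated re-tracing, while the uncompiled module simply runs eagerly. A toy illustration of the effect (names and shapes are illustrative, not from the repo):

import torch
import torch.nn as nn

net = nn.Linear(8, 8)
compiled = torch.compile(net, dynamic=False)  # specializes on input shapes

for T in (4, 5, 6):            # stand-in for variable-length eval batches
    x = torch.randn(2, T, 8)
    compiled(x)                # each new T forces another trace/compile
    net(x)                     # the uncompiled module has no such overhead
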
@@ -248,7 +248,7 @@ if master_process:
     save_checkpoint(
         checkpoint_dir,
         step,
-        model.state_dict(),
+        orig_model.state_dict(),
         None, # note: we don't bother to save the optimizer state
         {
             "step": step,