diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index e6b8ea5..88f2749 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -69,8 +69,8 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
 # Load the model and tokenizer
 model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
 orig_model = model # original, uncompiled model
-model = torch.compile(model, dynamic=False) # Fixed shapes enable efficient compilation
-engine = Engine(model, tokenizer) # will be used for inline model evaluation only
+# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
+engine = Engine(orig_model, tokenizer) # will be used for inline model evaluation only

 # -----------------------------------------------------------------------------
 # Task data mixture we'll train on
@@ -189,8 +189,8 @@ for step in range(num_iterations):
         metrics = {}
         with torch.no_grad(), autocast_ctx:
             # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
-            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
-            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
+            metrics["mmlu_acc"] = run_chat_eval("MMLU", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
+            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
         metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
         print0(f"Step {step:05d} | {metrics_str}")
         wandb_run.log({
@@ -248,7 +248,7 @@ if master_process:
     save_checkpoint(
         checkpoint_dir,
         step,
-        model.state_dict(),
+        orig_model.state_dict(),
         None, # note: we don't bother to save the optimizer state
         {
             "step": step,
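
Context for the change above, as a minimal sketch rather than anything from this patch: with dynamic=False, torch.compile specializes on input shapes, so a fine-tuning loop that feeds batches of varying sequence length triggers a recompilation for each new length, while dynamic=True avoids the recompiles but, per the comment in the diff, still doesn't work well here. If compilation were to be kept instead of using orig_model everywhere, one common workaround is to pad every batch to one fixed length before the forward pass; the pad_to_block helper and PAD_ID below are hypothetical names, not part of chat_sft.py.

import torch

PAD_ID = 0  # hypothetical pad token id; nanochat's tokenizer may use a different one

def pad_to_block(ids: torch.Tensor, block_size: int, pad_id: int = PAD_ID) -> torch.Tensor:
    # Right-pad (or truncate) a (B, T) batch of token ids to a fixed (B, block_size) shape,
    # so a model compiled with torch.compile(dynamic=False) sees a single input signature
    # and compiles one graph instead of recompiling for every distinct sequence length.
    B, T = ids.shape
    if T >= block_size:
        return ids[:, :block_size]
    pad = torch.full((B, block_size - T), pad_id, dtype=ids.dtype, device=ids.device)
    return torch.cat([ids, pad], dim=1)

# usage sketch: every batch becomes (B, 2048) regardless of its original length
# x = pad_to_block(x, 2048)
# logits = compiled_model(x)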