diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 8389deb..b5ba49a 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -11,7 +11,6 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
 
 import os
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
-import copy
 import wandb
 import torch
@@ -23,11 +22,9 @@ from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.engine import Engine
 from scripts.chat_eval import run_chat_eval
 
-from tasks.common import TaskMixture, TaskSequence
-from tasks.mmlu import MMLU
+from tasks.common import TaskMixture
 from tasks.arc import ARC
 from tasks.gsm8k import GSM8K
-from tasks.humaneval import HumanEval
 from tasks.smoltalk import SmolTalk
 
 # -----------------------------------------------------------------------------
@@ -186,7 +183,7 @@ for step in range(num_iterations):
             })
             model.train()
 
-    # evlauate MMLU accuracy
+    # evaluate accuracy of the multiple choice tasks (which are quick to run)
     if last_step or (step > 0 and step % eval_metrics_every == 0):
         model.eval()
         metrics = {}
@@ -194,8 +191,6 @@
             # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
             metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
-            metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64)
-            metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64)
         metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
         print0(f"Step {step:05d} | {metrics_str}")
         wandb_run.log({
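The patch drops GSM8K and HumanEval from the in-loop metrics because, unlike the multiple-choice tasks, they require full generation and are slow. Below is a minimal, hedged sketch of how those two evals could still be run a single time at the end of training; the `run_chat_eval(...)` calls and the `max_problems=64` budget are copied from the lines removed above, while the `last_step` guard, variable names, and the placement after the loop are assumptions about the surrounding script, not part of the patch.

```python
# Hypothetical follow-up, not part of the patch: run the slower generative
# evals (GSM8K, HumanEval) once at the final step instead of every
# eval_metrics_every steps. The run_chat_eval(...) signature and the
# max_problems=64 budget are taken from the lines removed in the diff above;
# last_step, model, tokenizer, engine, and print0 come from the surrounding script.
if last_step:
    model.eval()
    with torch.no_grad():
        final_metrics = {
            "gsm8k_acc": run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64),
            "humaneval_acc": run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64),
        }
    print0(", ".join(f"{k}: {v:.6f}" for k, v in final_metrics.items()))
    model.train()
```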