diff --git a/scripts/base_train.py b/scripts/base_train.py index 7ed6330..a1adbb9 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -28,7 +28,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine from nanochat.flash_attention import HAS_FA3 -from scripts.base_eval import evaluate_model +from scripts.base_eval import evaluate_core print_banner() # ----------------------------------------------------------------------------- @@ -305,7 +305,7 @@ while True: if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)): model.eval() with autocast_ctx: - results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) + results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") wandb_run.log({ "step": step,