Sample with orig_model instead of model in scripts/base_train.py.

The periodic sampling step now passes orig_model to evaluate_sample,
matching the adjacent comment ("use the original uncompiled model because
the inputs keep changing shape"). model.eval() / model.train() are unchanged.

NOTE(review): the leading indentation of the hunk lines below is
reconstructed (assumed 4-space nesting under `while True:`); confirm against
the target file or apply with whitespace tolerance.

diff --git a/scripts/base_train.py b/scripts/base_train.py
index 01edd1db..cde42a13 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -408,7 +408,7 @@ while True:
     # use the original uncompiled model because the inputs keep changing shape
     if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
         model.eval()
-        evaluate_sample(model, tokenizer, lambda x:print0(x), True)
+        evaluate_sample(orig_model, tokenizer, lambda x:print0(x), True)
         model.train()
 
     # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step