From ddf96c17c5244b3b9345af9e306beb2a764c0c5d Mon Sep 17 00:00:00 2001
From: jolonf <22928840+jolonf@users.noreply.github.com>
Date: Sun, 18 Jan 2026 21:46:35 +1100
Subject: [PATCH] Fix for issue #446 - moved `save_checkpoint()` above
 `evaluate_model()` so that the checkpoint is saved before the evals are run.

---
 scripts/base_train.py | 48 ++++++++++++++++++++++------------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/scripts/base_train.py b/scripts/base_train.py
index e051f99..95f1e26 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -298,6 +298,30 @@ while True:
             })
         model.train()
 
+    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
+    if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
+        save_checkpoint(
+            checkpoint_dir,
+            step,
+            orig_model.state_dict(), # model parameters
+            [opt.state_dict() for opt in optimizers], # optimizer states
+            { # metadata saved as json
+                "step": step,
+                "val_bpb": val_bpb, # loss at last step
+                "model_config": model_config_kwargs,
+                "user_config": user_config, # inputs to the training script
+                "device_batch_size": args.device_batch_size,
+                "max_seq_len": args.max_seq_len,
+                "dataloader_state_dict": dataloader_state_dict,
+                "loop_state": { # all loop state (other than step) so that we can resume training
+                    "min_val_bpb": min_val_bpb,
+                    "smooth_train_loss": smooth_train_loss,
+                    "total_training_time": total_training_time,
+                },
+            },
+            rank=ddp_rank,
+        )
+
     # once in a while: estimate the CORE metric (all ranks participate)
     # use the original uncompiled model because the inputs keep changing shape
     results = {}
@@ -335,30 +359,6 @@ while True:
             print0(tokenizer.decode(sample[0]))
         model.train()
 
-    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
-    if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
-        save_checkpoint(
-            checkpoint_dir,
-            step,
-            orig_model.state_dict(), # model parameters
-            [opt.state_dict() for opt in optimizers], # optimizer states
-            { # metadata saved as json
-                "step": step,
-                "val_bpb": val_bpb, # loss at last step
-                "model_config": model_config_kwargs,
-                "user_config": user_config, # inputs to the training script
-                "device_batch_size": args.device_batch_size,
-                "max_seq_len": args.max_seq_len,
-                "dataloader_state_dict": dataloader_state_dict,
-                "loop_state": { # all loop state (other than step) so that we can resume training
-                    "min_val_bpb": min_val_bpb,
-                    "smooth_train_loss": smooth_train_loss,
-                    "total_training_time": total_training_time,
-                },
-            },
-            rank=ddp_rank,
-        )
-
     # termination conditions (TODO: possibly also add loss explosions etc.)
     if last_step:
         break
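
A minimal, self-contained sketch of the save-before-eval ordering this patch enforces: the checkpoint for a step hits disk before the CORE-metric eval and sampling run, so a failure during those evals can no longer lose it. Everything below is a toy under assumed names; this save_checkpoint, flaky_eval, and ckpt_path are hypothetical stand-ins, not the training script's real functions.

    # Toy illustration of the patched loop order: write the checkpoint first,
    # run the (possibly failing) evals after. All names here are hypothetical
    # placeholders, not the functions from scripts/base_train.py.
    import json
    import os
    import tempfile

    def save_checkpoint(path, step, state):
        # stand-in for the real save_checkpoint(): dump step + state to disk
        with open(path, "w") as f:
            json.dump({"step": step, **state}, f)

    def flaky_eval(step):
        # stand-in for the CORE-metric eval; pretend it crashes at step 2 (e.g. OOM)
        if step == 2:
            raise RuntimeError("eval crashed")
        return {"core_metric": 0.1 * step}

    ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt.json")
    for step in range(4):
        # checkpoint BEFORE the eval, mirroring the patched loop order
        save_checkpoint(ckpt_path, step, {"val_bpb": 1.0 / (step + 1)})
        try:
            print(f"step {step}: eval -> {flaky_eval(step)}")
        except RuntimeError as err:
            print(f"step {step}: eval failed ({err}), checkpoint already saved")

    with open(ckpt_path) as f:
        print("latest checkpoint on disk:", json.load(f))

With the old ordering, the step-2 eval failure would have killed the step before save_checkpoint() ran; with the patched ordering, the step-2 checkpoint is already on disk when the eval dies.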