diff --git a/scripts/base_train.py b/scripts/base_train.py index 2d61477..b81d374 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -296,6 +296,30 @@ while True: }) model.train() + # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step + if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0): + save_checkpoint( + checkpoint_dir, + step, + orig_model.state_dict(), # model parameters + [opt.state_dict() for opt in optimizers], # optimizer states + { # metadata saved as json + "step": step, + "val_bpb": val_bpb, # loss at last step + "model_config": model_config_kwargs, + "user_config": user_config, # inputs to the training script + "device_batch_size": args.device_batch_size, + "max_seq_len": args.max_seq_len, + "dataloader_state_dict": dataloader_state_dict, + "loop_state": { # all loop state (other than step) so that we can resume training + "min_val_bpb": min_val_bpb, + "smooth_train_loss": smooth_train_loss, + "total_training_time": total_training_time, + }, + }, + rank=ddp_rank, + ) + # once in a while: estimate the CORE metric (all ranks participate) # use the original uncompiled model because the inputs keep changing shape results = {} @@ -333,30 +357,6 @@ while True: print0(tokenizer.decode(sample[0])) model.train() - # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step - if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0): - save_checkpoint( - checkpoint_dir, - step, - orig_model.state_dict(), # model parameters - [opt.state_dict() for opt in optimizers], # optimizer states - { # metadata saved as json - "step": step, - "val_bpb": val_bpb, # loss at last step - "model_config": model_config_kwargs, - "user_config": user_config, # inputs to the training script - "device_batch_size": args.device_batch_size, - "max_seq_len": args.max_seq_len, - "dataloader_state_dict": dataloader_state_dict, - "loop_state": { # all loop state (other than step) so that we can resume training - "min_val_bpb": min_val_bpb, - "smooth_train_loss": smooth_train_loss, - "total_training_time": total_training_time, - }, - }, - rank=ddp_rank, - ) - # termination conditions (TODO: possibly also add loss explosions etc.) if last_step: break