From 36d132eb23c7799f90c4e54c42026c953693fb90 Mon Sep 17 00:00:00 2001
From: Nitish Pandey
Date: Sun, 16 Nov 2025 06:48:13 +0530
Subject: [PATCH 1/2] save checkpoint before possible OOM in CORE metric

---
 scripts/base_train.py | 49 ++++++++++++++++++++++---------------------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/scripts/base_train.py b/scripts/base_train.py
index c9ea6c9..4329cc6 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -233,6 +233,31 @@ while True:
             })
         model.train()
 
+    # save checkpoint: at the last step, or every save_every steps, except at the first step or the resume step
+    # save progress before possible OOM in CORE metric
+    if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
+        save_checkpoint(
+            checkpoint_dir,
+            step,
+            orig_model.state_dict(), # model parameters
+            [opt.state_dict() for opt in optimizers], # optimizer states
+            { # metadata saved as json
+                "step": step,
+                "val_bpb": val_bpb, # loss at last step
+                "model_config": model_config_kwargs,
+                "user_config": user_config, # inputs to the training script
+                "device_batch_size": device_batch_size,
+                "max_seq_len": max_seq_len,
+                "dataloader_state_dict": dataloader_state_dict,
+                "loop_state": { # all loop state (other than step) so that we can resume training
+                    "min_val_bpb": min_val_bpb,
+                    "smooth_train_loss": smooth_train_loss,
+                    "total_training_time": total_training_time,
+                },
+            },
+            rank=ddp_rank,
+        )
+
     # once in a while: estimate the CORE metric (all ranks participate)
     # use the original uncompiled model because the inputs keep changing shape
     results = {}
@@ -270,30 +295,6 @@ while True:
             print0(tokenizer.decode(sample[0]))
         model.train()
 
-    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
-    if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
-        save_checkpoint(
-            checkpoint_dir,
-            step,
-            orig_model.state_dict(), # model parameters
-            [opt.state_dict() for opt in optimizers], # optimizer states
-            { # metadata saved as json
-                "step": step,
-                "val_bpb": val_bpb, # loss at last step
-                "model_config": model_config_kwargs,
-                "user_config": user_config, # inputs to the training script
-                "device_batch_size": device_batch_size,
-                "max_seq_len": max_seq_len,
-                "dataloader_state_dict": dataloader_state_dict,
-                "loop_state": { # all loop state (other than step) so that we can resume training
-                    "min_val_bpb": min_val_bpb,
-                    "smooth_train_loss": smooth_train_loss,
-                    "total_training_time": total_training_time,
-                },
-            },
-            rank=ddp_rank,
-        )
-
     # termination conditions (TODO: possibly also add loss explosions etc.)
     if last_step:
         break

From 0d60d74dc0a4c3e4b521365c7a547d0ff38da6b3 Mon Sep 17 00:00:00 2001
From: Nitish Pandey
Date: Sun, 16 Nov 2025 06:50:41 +0530
Subject: [PATCH 2/2] update comment

---
 scripts/base_train.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/scripts/base_train.py b/scripts/base_train.py
index 4329cc6..080ca38 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -234,7 +234,6 @@ while True:
         model.train()
 
     # save checkpoint: at the last step, or every save_every steps, except at the first step or the resume step
-    # save progress before possible OOM in CORE metric
     if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
         save_checkpoint(
             checkpoint_dir,
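
Note (not part of the patch series): below is a minimal, self-contained Python sketch of the idea behind PATCH 1/2, namely saving the checkpoint before running the memory-hungry CORE metric so that an OOM during evaluation cannot lose progress. The names toy_save_checkpoint and toy_core_eval are invented for illustration and do not exist in scripts/base_train.py.

import json
import os
import tempfile

def toy_save_checkpoint(ckpt_dir, step, state):
    # stand-in for save_checkpoint(): persist step metadata to disk
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(os.path.join(ckpt_dir, f"step_{step}.json"), "w") as f:
        json.dump({"step": step, **state}, f)

def toy_core_eval(step):
    # stand-in for the CORE metric; simulate an OOM at one step
    if step == 2:
        raise MemoryError("simulated OOM during CORE metric")
    return {"core_score": 0.1 * step}

ckpt_dir = tempfile.mkdtemp()
for step in range(4):
    try:
        toy_save_checkpoint(ckpt_dir, step, {"loss": 1.0 / (step + 1)})  # save first
        toy_core_eval(step)                                              # eval may OOM afterwards
    except MemoryError as e:
        print(f"step {step}: eval failed ({e}), but this step's checkpoint is already on disk")
print(sorted(os.listdir(ckpt_dir)))  # checkpoints exist for every step, including the failed one

Running the sketch prints a message for the step whose evaluation "OOMs", and the final listing shows a checkpoint for that step regardless; this is the same guarantee the reordering in PATCH 1/2 gives base_train.py.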