diff --git a/scripts/base_train.py b/scripts/base_train.py
index c7c5bba..55b9ca6 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -12,6 +12,7 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -
 
 """
 import os
+import signal
 os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
 import argparse
 import time
@@ -68,6 +69,7 @@ parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluat
 parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
 parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
 parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
+parser.add_argument("--save_on_sigint", action="store_true", help="save a checkpoint when exiting prematurely")
 # Output
 parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
 args = parser.parse_args()
@@ -258,6 +260,38 @@ else:
 
 # -----------------------------------------------------------------------------
 # Training loop
+
+
+def save_on_interrupt(sig, frame):
+    save_checkpoint(
+        checkpoint_dir,
+        step,
+        orig_model.state_dict(), # model parameters
+        [opt.state_dict() for opt in optimizers], # optimizer states
+        { # metadata saved as json
+            "step": step,
+            "val_bpb": val_bpb, # loss at last step
+            "model_config": model_config_kwargs,
+            "user_config": user_config, # inputs to the training script
+            "device_batch_size": args.device_batch_size,
+            "max_seq_len": args.max_seq_len,
+            "dataloader_state_dict": dataloader_state_dict,
+            "loop_state": { # all loop state (other than step) so that we can resume training
+                "min_val_bpb": min_val_bpb,
+                "smooth_train_loss": smooth_train_loss,
+                "total_training_time": total_training_time,
+            },
+        },
+        rank=ddp_rank,
+    )
+    compute_cleanup()
+    print("So long and thanks for all the fish!")
+    exit()
+
+if args.save_on_sigint:
+    signal.signal(signal.SIGINT, save_on_interrupt)
+
+
 while True:
     last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
     flops_so_far = num_flops_per_token * args.total_batch_size * step
@@ -412,6 +446,7 @@ while True:
     # state update
     step += 1
+
 
 # print a few more stats
 print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
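
A minimal usage sketch for the new flag (the base arguments just mirror the example invocation at the top of base_train.py; exact values are illustrative):

    python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --save_on_sigint

With --save_on_sigint set, pressing Ctrl+C sends SIGINT, and the handler writes a regular checkpoint (model weights, optimizer states, dataloader state, and loop state) to checkpoint_dir via save_checkpoint, calls compute_cleanup(), and exits; without the flag, the default KeyboardInterrupt behavior is unchanged.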