This commit is contained in:
suspicious-pineapple 2026-01-12 07:15:53 +01:00 committed by GitHub
commit a88a94f5bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,6 +12,7 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -
"""
import os
import signal
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import argparse
import time
@ -68,6 +69,7 @@ parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluat
parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# This flag only installs a SIGINT handler; the help text says so explicitly so that nobody
# expects a checkpoint after other kinds of early exit (e.g. SIGTERM or an exception).
parser.add_argument("--save_on_sigint", action="store_true", help="save a checkpoint and exit when interrupted with SIGINT (Ctrl+C)")
# Output
parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
@ -258,6 +260,38 @@ else:
# -----------------------------------------------------------------------------
# Training loop
def save_on_interrupt(sig, frame):
    """Signal handler: save a checkpoint for the current step, clean up, and exit.

    Registered for SIGINT just below when --save_on_sigint is passed. The checkpoint
    carries the model parameters, every optimizer's state, and the metadata needed to
    resume (step, last val_bpb, configs, dataloader state, loop state).

    Args:
        sig: number of the signal that was delivered (unused).
        frame: stack frame that was executing when the signal arrived (unused).

    Raises:
        SystemExit: always, after the checkpoint has been written and
            compute_cleanup() has run.
    """
    import sys  # local import so the handler does not depend on the file-level import list

    # A second Ctrl+C while the checkpoint is being written would re-enter this handler
    # and start another save on top of the half-written one, so ignore SIGINT from here on.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    # NOTE(review): the handler runs wherever the main thread happens to be, possibly in
    # the middle of a training step; confirm the state captured here is consistent enough
    # to resume from.
    save_checkpoint(
        checkpoint_dir,
        step,
        orig_model.state_dict(), # model parameters
        [opt.state_dict() for opt in optimizers], # optimizer states
        { # metadata saved as json
            "step": step,
            "val_bpb": val_bpb, # loss at last step
            "model_config": model_config_kwargs,
            "user_config": user_config, # inputs to the training script
            "device_batch_size": args.device_batch_size,
            "max_seq_len": args.max_seq_len,
            "dataloader_state_dict": dataloader_state_dict,
            "loop_state": { # all loop state (other than step) so that we can resume training
                "min_val_bpb": min_val_bpb,
                "smooth_train_loss": smooth_train_loss,
                "total_training_time": total_training_time,
            },
        },
        rank=ddp_rank,
    )
    compute_cleanup()
    print("So long and thanks for all the fish!")
    # sys.exit() rather than the bare exit(): the latter is injected by the `site` module
    # for interactive use and is not guaranteed to exist in every interpreter setup.
    sys.exit()


# Only take over Ctrl+C when the user explicitly asked for it.
if args.save_on_sigint:
    signal.signal(signal.SIGINT, save_on_interrupt)
while True:
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
flops_so_far = num_flops_per_token * args.total_batch_size * step
@ -412,6 +446,7 @@ while True:
# state update
step += 1
# print a few more stats
# get_max_memory() is divided by 1024 twice and labelled MiB, so it presumably returns bytes — TODO confirm
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
# total_training_time is divided by 60 and labelled "m", so it is presumably accumulated in seconds — TODO confirm
print0(f"Total training time: {total_training_time/60:.2f}m")