mirror of https://github.com/karpathy/nanochat.git
synced 2026-02-11 04:59:50 +00:00

Merge 07e5509662 into 4610a838a1
This commit is contained in: commit a88a94f5bd
@@ -12,6 +12,7 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -
 """
 
 import os
+import signal
 os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
 import argparse
 import time
@@ -68,6 +69,7 @@ parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluat
 parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
 parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
 parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
+parser.add_argument("--save_on_sigint", action="store_true", help="save a checkpoint when exiting prematurely")
 # Output
 parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
 args = parser.parse_args()
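The new flag is a plain argparse store_true switch: absent it is False, so the SIGINT handler added below is strictly opt-in. A minimal standalone sketch of that behavior (the many surrounding arguments from base_train.py are omitted here):

import argparse

# Hypothetical minimal parser; base_train.py defines many more arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--save_on_sigint", action="store_true",
                    help="save a checkpoint when exiting prematurely")

assert parser.parse_args([]).save_on_sigint is False            # default: off
assert parser.parse_args(["--save_on_sigint"]).save_on_sigint   # opt-in: on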
@@ -258,6 +260,38 @@ else:
 # -----------------------------------------------------------------------------
 # Training loop
 
+
+def save_on_interrupt(sig, frame):
+    save_checkpoint(
+        checkpoint_dir,
+        step,
+        orig_model.state_dict(), # model parameters
+        [opt.state_dict() for opt in optimizers], # optimizer states
+        { # metadata saved as json
+            "step": step,
+            "val_bpb": val_bpb, # loss at last step
+            "model_config": model_config_kwargs,
+            "user_config": user_config, # inputs to the training script
+            "device_batch_size": args.device_batch_size,
+            "max_seq_len": args.max_seq_len,
+            "dataloader_state_dict": dataloader_state_dict,
+            "loop_state": { # all loop state (other than step) so that we can resume training
+                "min_val_bpb": min_val_bpb,
+                "smooth_train_loss": smooth_train_loss,
+                "total_training_time": total_training_time,
+            },
+        },
+        rank=ddp_rank,
+    )
+    compute_cleanup()
+    print("So long and thanks for all the fish!")
+    exit()
+
+if args.save_on_sigint:
+    signal.signal(signal.SIGINT, save_on_interrupt)
+
+
 while True:
     last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
     flops_so_far = num_flops_per_token * args.total_batch_size * step
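The mechanism here is Python's standard signal module: signal.signal(signal.SIGINT, save_on_interrupt) replaces the default KeyboardInterrupt behavior, so Ctrl-C runs the save-and-exit path instead of unwinding the loop with a traceback. A minimal standalone sketch of the same pattern, with hypothetical stand-ins for the real training state:

import os
import signal
import sys
import tempfile
import time

import torch

# Hypothetical stand-ins; in base_train.py the handler closes over the live
# model, optimizers, dataloader state, and loop counters instead.
model = torch.nn.Linear(4, 4)
step = 0

def save_on_interrupt(sig, frame):
    # Persist enough state to resume later, then exit cleanly rather than
    # letting the default KeyboardInterrupt tear down the loop mid-iteration.
    path = os.path.join(tempfile.gettempdir(), "interrupt_checkpoint.pt")
    torch.save({"step": step, "model": model.state_dict()}, path)
    print(f"saved interrupt checkpoint to {path}")
    sys.exit(0)

signal.signal(signal.SIGINT, save_on_interrupt)

# Toy "training loop": press Ctrl-C at any point to trigger the handler.
while True:
    step += 1
    time.sleep(0.01)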
@@ -412,6 +446,7 @@ while True:
     # state update
     step += 1
+
 
 # print a few more stats
 print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
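One design note: the handler above performs the save directly inside the signal callback, which CPython may invoke between any two bytecode instructions, so in principle it can fire in the middle of an optimizer step. A common alternative (not what this commit does) is for the handler to only set a flag and for the loop to checkpoint at the next safe point, e.g.:

import signal

stop_requested = False

def request_stop(sig, frame):
    # Defer the save to a known-safe point instead of doing it here.
    global stop_requested
    stop_requested = True

signal.signal(signal.SIGINT, request_stop)

step = 0
while not stop_requested and step < 1_000_000:
    step += 1  # one (toy) training step
print(f"stopped at step {step}; a checkpoint could be saved safely here")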