This commit is contained in:
jolonf 2026-02-13 14:53:16 +01:00 committed by GitHub
commit 8810c3c68b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -416,6 +416,31 @@ while True:
})
model.train()
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (
step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(), # model parameters
optimizer.state_dict(), # optimizer state
{ # metadata saved as json
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script
"device_batch_size": args.device_batch_size,
"max_seq_len": args.max_seq_len,
"dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,
"smooth_train_loss": smooth_train_loss,
"total_training_time": total_training_time,
},
},
rank=ddp_rank,
)
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
# disable FP8 for evaluation to use BF16 for more consistent/accurate results
@ -454,30 +479,6 @@ while True:
print0(tokenizer.decode(sample[0]))
model.train()
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(), # model parameters
optimizer.state_dict(), # optimizer state
{ # metadata saved as json
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script
"device_batch_size": args.device_batch_size,
"max_seq_len": args.max_seq_len,
"dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,
"smooth_train_loss": smooth_train_loss,
"total_training_time": total_training_time,
},
},
rank=ddp_rank,
)
# termination conditions (TODO: possibly also add loss explosions etc.)
if last_step:
break