mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Compare commits
2 Commits
4a87a0d19f
...
0d60d74dc0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0d60d74dc0 | ||
|
|
36d132eb23 |
|
|
@ -233,6 +233,30 @@ while True:
|
|||
})
|
||||
model.train()
|
||||
|
||||
# save checkpoint: at the last step, or every save_every steps, except at the first step or the resume step
|
||||
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(), # model parameters
|
||||
[opt.state_dict() for opt in optimizers], # optimizer states
|
||||
{ # metadata saved as json
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": model_config_kwargs,
|
||||
"user_config": user_config, # inputs to the training script
|
||||
"device_batch_size": device_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
"dataloader_state_dict": dataloader_state_dict,
|
||||
"loop_state": { # all loop state (other than step) so that we can resume training
|
||||
"min_val_bpb": min_val_bpb,
|
||||
"smooth_train_loss": smooth_train_loss,
|
||||
"total_training_time": total_training_time,
|
||||
},
|
||||
},
|
||||
rank=ddp_rank,
|
||||
)
|
||||
|
||||
# once in a while: estimate the CORE metric (all ranks participate)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
results = {}
|
||||
|
|
@ -270,30 +294,6 @@ while True:
|
|||
print0(tokenizer.decode(sample[0]))
|
||||
model.train()
|
||||
|
||||
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
|
||||
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(), # model parameters
|
||||
[opt.state_dict() for opt in optimizers], # optimizer states
|
||||
{ # metadata saved as json
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": model_config_kwargs,
|
||||
"user_config": user_config, # inputs to the training script
|
||||
"device_batch_size": device_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
"dataloader_state_dict": dataloader_state_dict,
|
||||
"loop_state": { # all loop state (other than step) so that we can resume training
|
||||
"min_val_bpb": min_val_bpb,
|
||||
"smooth_train_loss": smooth_train_loss,
|
||||
"total_training_time": total_training_time,
|
||||
},
|
||||
},
|
||||
rank=ddp_rank,
|
||||
)
|
||||
|
||||
# termination conditions (TODO: possibly also add loss explosions etc.)
|
||||
if last_step:
|
||||
break
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user