Compare commits

..

2 Commits

Author SHA1 Message Date
Andrej
4a87a0d19f
Merge pull request #299 from samjabrahams/rotary_embedding_head_dim_comment_cleanup
Fix comment: rotary embeddings final dimension size
2025-11-17 13:29:21 -08:00
Sam Abrahams
11e68bf442 Fix comment: rotary embeddings final dimension size 2025-11-17 11:32:56 -05:00
2 changed files with 25 additions and 25 deletions

View File

@ -244,7 +244,7 @@ class GPT(nn.Module):
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"

View File

@ -233,30 +233,6 @@ while True:
})
model.train()
# save checkpoint: at the last step, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(), # model parameters
[opt.state_dict() for opt in optimizers], # optimizer states
{ # metadata saved as json
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script
"device_batch_size": device_batch_size,
"max_seq_len": max_seq_len,
"dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,
"smooth_train_loss": smooth_train_loss,
"total_training_time": total_training_time,
},
},
rank=ddp_rank,
)
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
results = {}
@ -294,6 +270,30 @@ while True:
print0(tokenizer.decode(sample[0]))
model.train()
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(), # model parameters
[opt.state_dict() for opt in optimizers], # optimizer states
{ # metadata saved as json
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script
"device_batch_size": device_batch_size,
"max_seq_len": max_seq_len,
"dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,
"smooth_train_loss": smooth_train_loss,
"total_training_time": total_training_time,
},
},
rank=ddp_rank,
)
# termination conditions (TODO: possibly also add loss explosions etc.)
if last_step:
break