Compare commits

4 Commits

Author        SHA1         Message                                                Date
burtenshaw    dae01075f0   Merge 919ea572b0 into 4a87a0d19f                       2025-11-17 23:19:15 -04:00
Andrej        4a87a0d19f   Merge pull request #299 from samjabrahams/rotary_embedding_head_dim_comment_cleanup (Fix comment: rotary embeddings final dimension size)   2025-11-17 13:29:21 -08:00
Sam Abrahams  11e68bf442   Fix comment: rotary embeddings final dimension size    2025-11-17 11:32:56 -05:00
burtenshaw    919ea572b0   add modular imports to the init                        2025-10-14 14:55:03 +02:00
2 changed files with 34 additions and 1 deletion

@@ -0,0 +1,33 @@
+# Import all submodules used by scripts
+from . import common
+from . import tokenizer
+from . import checkpoint_manager
+from . import core_eval
+from . import gpt
+from . import dataloader
+from . import loss_eval
+from . import engine
+from . import dataset
+from . import report
+from . import adamw
+from . import muon
+from . import configurator
+from . import execution
+# Make submodules available
+__all__ = [
+    "common",
+    "tokenizer",
+    "checkpoint_manager",
+    "core_eval",
+    "gpt",
+    "dataloader",
+    "loss_eval",
+    "engine",
+    "dataset",
+    "report",
+    "adamw",
+    "muon",
+    "configurator",
+    "execution",
+]
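
For reference, a minimal sketch of what this new __init__.py enables. The package name nanochat is an assumption here (the compare view does not show the file path), so treat the snippet as illustrative only:

# Illustrative usage; "nanochat" is an assumed package name, not shown in the diff.
import nanochat

# __init__.py now imports every submodule eagerly, so plain attribute access
# works without a separate "from nanochat import gpt":
gpt_module = nanochat.gpt
print(nanochat.__all__)    # ['common', 'tokenizer', ..., 'execution']

# __all__ also pins down what a wildcard import binds:
# from nanochat import *   # binds common, tokenizer, gpt, ... and nothing else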

@@ -244,7 +244,7 @@ class GPT(nn.Module):
     def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
         B, T = idx.size()
-        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
+        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
         assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
         assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
         assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
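
The corrected comment matches how rotary position embeddings are typically precomputed: channels are rotated in pairs, so the cos/sin caches need only head_dim/2 frequencies per position, not head_dim. A minimal sketch of that precomputation (illustrative only, not the repository's exact code):

import torch

def precompute_rotary(seq_len, head_dim, base=10000.0):
    # One inverse frequency per channel pair -> head_dim // 2 values
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)        # (seq_len, head_dim/2)
    cos = freqs.cos()[None, :, None, :]     # (1, seq_len, 1, head_dim/2)
    sin = freqs.sin()[None, :, None, :]
    return cos, sin

cos, sin = precompute_rotary(seq_len=8, head_dim=64)
print(cos.shape)                            # torch.Size([1, 8, 1, 32])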