diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 9466e27..90b5d98 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -23,7 +23,6 @@ from nanochat.common import get_dist_info from nanochat.muon import Muon, DistMuon from nanochat.adamw import DistAdamW - @dataclass class GPTConfig: sequence_len: int = 1024 @@ -49,7 +48,6 @@ def apply_rotary_emb(x, cos, sin): out = out.to(x.dtype) # ensure input/output dtypes match return out - class CausalSelfAttention(nn.Module): def __init__(self, config, layer_idx): super().__init__()