diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 208acd1..324ec69 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -49,12 +49,11 @@ def has_ve(layer_idx, n_layer):
     return layer_idx % 2 == (n_layer - 1) % 2
 
 def apply_rotary_emb(x, cos, sin):
-    assert x.ndim == 4 # multihead attention
-    d = x.shape[3] // 2
-    x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
-    y1 = x1 * cos + x2 * sin # rotate pairs of dims
+    assert x.ndim == 4 # (B, T, H, D) multihead attention layout
+    x1, x2 = x.chunk(2, dim=-1) # split head_dim into two halves
+    y1 = x1 * cos + x2 * sin
     y2 = x1 * (-sin) + x2 * cos
-    return torch.cat([y1, y2], 3)
+    return torch.cat([y1, y2], dim=-1)
 
 class CausalSelfAttention(nn.Module):
     def __init__(self, config, layer_idx):
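
A minimal sanity check for the refactor above (a sketch accompanying the patch, not part of it): it assumes PyTorch and that cos/sin broadcast against the (B, T, H, D/2) halves, per the layout noted in the new assert comment. The shapes and the helper names apply_rotary_emb_old / apply_rotary_emb_new are illustrative, not from the repo.

import torch

def apply_rotary_emb_old(x, cos, sin):
    # pre-patch version: explicit slicing of the last dim into two halves
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3)

def apply_rotary_emb_new(x, cos, sin):
    # post-patch version: chunk(2, dim=-1) yields the same two halves for even D
    x1, x2 = x.chunk(2, dim=-1)
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], dim=-1)

B, T, H, D = 2, 8, 4, 16                  # hypothetical sizes for the check
x = torch.randn(B, T, H, D)
theta = torch.randn(T, D // 2)            # assumed rotary angles, one per pair
cos = theta.cos()[None, :, None, :]       # broadcast to (1, T, 1, D/2)
sin = theta.sin()[None, :, None, :]
assert torch.equal(apply_rotary_emb_old(x, cos, sin),
                   apply_rotary_emb_new(x, cos, sin))

Since chunk() on an even last dim produces exactly the [:d] and [d:] views, the elementwise ops are identical and the outputs match bitwise; the diff is a pure readability refactor with no numerical change.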