This commit is contained in:
fpvsim 2026-03-05 17:40:47 +01:00 committed by GitHub
commit 96a6d7f280
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -58,8 +58,8 @@ def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
y1 = x1 * cos - x2 * sin # rotate pairs of dims
y2 = x1 * sin + x2 * cos
return torch.cat([y1, y2], 3)
class CausalSelfAttention(nn.Module):