mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-27 15:15:16 +00:00
Fix rotation calculations in apply_rotary_emb function
Reading some blogs on RoPE, it feels like the current implementation is a little off? Or am I missing something?
This commit is contained in:
parent
8b4849d548
commit
79b7b04ca0
|
|
@ -52,8 +52,8 @@ def apply_rotary_emb(x, cos, sin):
|
|||
assert x.ndim == 4 # multihead attention
|
||||
d = x.shape[3] // 2
|
||||
x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
|
||||
y1 = x1 * cos + x2 * sin # rotate pairs of dims
|
||||
y2 = x1 * (-sin) + x2 * cos
|
||||
y1 = x1 * cos - x2 * sin # rotate pairs of dims
|
||||
y2 = x1 * sin + x2 * cos
|
||||
return torch.cat([y1, y2], 3)
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user