Split with chunk

This commit is contained in:
Chris McCormick 2026-02-02 08:57:40 -08:00
parent 6d38e4bc88
commit 5129a34288

View File

@ -49,12 +49,11 @@ def has_ve(layer_idx, n_layer):
return layer_idx % 2 == (n_layer - 1) % 2
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
assert x.ndim == 4 # (B, T, H, D) multihead attention layout
x1, x2 = x.chunk(2, dim=-1) # split head_dim into two halves
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
return torch.cat([y1, y2], 3)
return torch.cat([y1, y2], dim=-1)
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):