perf: remove redundant dtype conversion in apply_rotary_emb

This commit is contained in:
spjosyula 2025-12-24 01:56:45 +05:30
parent bc51da8bac
commit 69cc309967

View File

@@ -44,9 +44,7 @@ def apply_rotary_emb(x, cos, sin):
x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
- out = torch.cat([y1, y2], 3) # re-assemble
- out = out.to(x.dtype) # ensure input/output dtypes match
- return out
+ return torch.cat([y1, y2], 3) # re-assemble
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):