mirror of
https://github.com/karpathy/nanochat.git
synced 2026-02-09 12:09:49 +00:00
perf: remove redundant dtype conversion in apply_rotary_emb
This commit is contained in:
parent
bc51da8bac
commit
69cc309967
|
|
@ -44,9 +44,7 @@ def apply_rotary_emb(x, cos, sin):
|
|||
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
|
||||
y1 = x1 * cos + x2 * sin # rotate pairs of dims
|
||||
y2 = x1 * (-sin) + x2 * cos
|
||||
out = torch.cat([y1, y2], 3) # re-assemble
|
||||
out = out.to(x.dtype) # ensure input/output dtypes match
|
||||
return out
|
||||
return torch.cat([y1, y2], 3) # re-assemble
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user