mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-15 10:39:08 +00:00
broadcast logic to get speedup.
This commit is contained in:
parent
275fa4b060
commit
aef1254205
|
|
@ -108,17 +108,21 @@ def XSA(atten, value, sdpa=False):
|
|||
vB, vH, vT, vD = value.shape
|
||||
rep = aH // vH
|
||||
assert aH % vH == 0, "query heads and kv heads must be divisible"
|
||||
value = torch.repeat_interleave(value, rep, dim=1)
|
||||
Vn = F.normalize(value, dim=-1)
|
||||
return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn
|
||||
|
||||
atten_r = atten.reshape(aB, vH, rep, aT, aD)
|
||||
Vn = F.normalize(value, dim=-1).unsqueeze(2)
|
||||
atten_r = atten_r - torch.sum(atten_r * Vn, dim=-1, keepdim=True) * Vn
|
||||
return atten_r.reshape(aB, aH, aT, aD)
|
||||
else:
|
||||
aB, aT, aH, aD = atten.shape
|
||||
vB, vT, vH, vD = value.shape
|
||||
rep = aH // vH
|
||||
assert aH % vH == 0, "query heads and kv heads must be divisible"
|
||||
value = torch.repeat_interleave(value, rep, dim=2)
|
||||
Vn = F.normalize(value, dim=-1)
|
||||
return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn
|
||||
|
||||
atten_r = atten.reshape(aB, aT, vH, rep, aD)
|
||||
Vn = F.normalize(value, dim=-1).unsqueeze(-2)
|
||||
atten_r = atten_r - torch.sum(atten_r * Vn, dim=-1, keepdim=True) * Vn
|
||||
return atten_r.reshape(aB, aT, aH, aD)
|
||||
|
||||
# =============================================================================
|
||||
# Public API: Same interface as FA3
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user