fix: enforce (B, H, T, D) layout for SDPA fallback to support CPU strictness

This commit is contained in:
hasan 2026-01-14 21:42:20 +01:00
parent 68e66be05c
commit d7fccbab82

View File

@ -28,9 +28,7 @@ import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
try:
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
except Exception:
@ -101,7 +99,14 @@ class CausalSelfAttention(nn.Module):
if flash_attn is not None:
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
else:
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Fallback Path (CPU/MPS): Needs Transpose to (B, H, T, D) for SDPA
y = F.scaled_dot_product_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
is_causal=True
)
y = y.transpose(1, 2) # Restore layout to (B, T, H, D)
else:
# Inference: use flash_attn_with_kvcache which handles cache management
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
@ -129,14 +134,15 @@ class CausalSelfAttention(nn.Module):
for b in range(B):
pos = positions[b].item()
# Retrieve context & FORCE CAST to match query dtype (Fix for float vs bfloat16 crash)
k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype)
v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype)
q_curr = q[b:b+1]
# Fetch history + cast to q.dtype (MPS fix) + Transpose (CPU fix)
k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype).transpose(1, 2)
v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype).transpose(1, 2)
q_curr = q[b:b+1].transpose(1, 2)
# Standard Attention (is_causal=False because context is explicit)
att_out = F.scaled_dot_product_attention(q_curr, k_curr, v_curr, is_causal=False)
y[b] = att_out[0]
# Transpose back to (B, T, H, D) and store
y[b] = att_out.transpose(1, 2)[0]
# Advance position after last layer processes
if self.layer_idx == kv_cache.n_layers - 1:
@ -440,4 +446,4 @@ class GPT(nn.Module):
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
ids = torch.cat((ids, next_ids), dim=1)
token = next_ids.item()
yield token
yield token