fix: add Flash Attention 3 fallback for MPS/CPU inference

This commit is contained in:
hasan 2026-01-14 01:10:29 +01:00
parent 7312ec9898
commit c9c01ffe04

View File

@ -28,8 +28,8 @@ import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
flash_attn = None
@dataclass
class GPTConfig:
@ -92,17 +92,46 @@ class CausalSelfAttention(nn.Module):
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
if flash_attn is not None:
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
else:
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
# Inference: use flash_attn_with_kvcache which handles cache management
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache,
k=k, v=v,
cache_seqlens=kv_cache.cache_seqlens,
causal=True,
window_size=window_size,
)
if flash_attn is not None:
# Optimized Path (Linux/CUDA with FA3)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache,
k=k, v=v,
cache_seqlens=kv_cache.cache_seqlens,
causal=True,
window_size=window_size,
)
else:
# Fallback Path (macOS/MPS or CPU) - Manual Cache Update + SDPA
positions = kv_cache.cache_seqlens
# Update cache manually for the batch
for b in range(B):
pos = positions[b].item()
k_cache[b, pos:pos+T] = k[b]
v_cache[b, pos:pos+T] = v[b]
# Compute attention manually
y = torch.empty_like(q)
for b in range(B):
pos = positions[b].item()
# Retrieve context & FORCE CAST to match query dtype (Fix for float vs bfloat16 crash)
k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype)
v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype)
q_curr = q[b:b+1]
# Standard Attention (is_causal=False because context is explicit)
att_out = F.scaled_dot_product_attention(q_curr, k_curr, v_curr, is_causal=False)
y[b] = att_out[0]
# Advance position after last layer processes
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)