From c9c01ffe04175f19df15688feea4bd12d6bace0a Mon Sep 17 00:00:00 2001 From: hasan Date: Wed, 14 Jan 2026 01:10:29 +0100 Subject: [PATCH] fix: add Flash Attention 3 fallback for MPS/CPU inference --- nanochat/gpt.py | 49 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 81ccb0c..38ba153 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -28,8 +28,8 @@ import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain. # Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal) -from kernels import get_kernel -flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface + +flash_attn = None @dataclass class GPTConfig: @@ -92,17 +92,46 @@ class CausalSelfAttention(nn.Module): # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context if kv_cache is None: # Training: causal attention with optional sliding window - y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) + if flash_attn is not None: + y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) + else: + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) else: # Inference: use flash_attn_with_kvcache which handles cache management k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx) - y = flash_attn.flash_attn_with_kvcache( - q, k_cache, v_cache, - k=k, v=v, - cache_seqlens=kv_cache.cache_seqlens, - causal=True, - window_size=window_size, - ) + + if flash_attn is not None: + # Optimized Path (Linux/CUDA with FA3) + y = flash_attn.flash_attn_with_kvcache( + q, k_cache, v_cache, + k=k, v=v, + cache_seqlens=kv_cache.cache_seqlens, + causal=True, + window_size=window_size, + ) + else: + # Fallback Path (macOS/MPS or CPU) - Manual Cache Update + SDPA + positions = kv_cache.cache_seqlens + # Update cache manually for the batch + for b in range(B): + pos = positions[b].item() + k_cache[b, pos:pos+T] = k[b] + v_cache[b, pos:pos+T] = v[b] + + # Compute attention manually + y = torch.empty_like(q) + for b in range(B): + pos = positions[b].item() + + # Retrieve context & FORCE CAST to match query dtype (Fix for float vs bfloat16 crash) + k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype) + v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype) + q_curr = q[b:b+1] + + # Standard Attention (is_causal=False because context is explicit) + att_out = F.scaled_dot_product_attention(q_curr, k_curr, v_curr, is_causal=False) + y[b] = att_out[0] + # Advance position after last layer processes if self.layer_idx == kv_cache.n_layers - 1: kv_cache.advance(T)