From c9c01ffe04175f19df15688feea4bd12d6bace0a Mon Sep 17 00:00:00 2001 From: hasan Date: Wed, 14 Jan 2026 01:10:29 +0100 Subject: [PATCH 1/4] fix: add Flash Attention 3 fallback for MPS/CPU inference --- nanochat/gpt.py | 49 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 81ccb0c..38ba153 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -28,8 +28,8 @@ import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain. # Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal) -from kernels import get_kernel -flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface + +flash_attn = None @dataclass class GPTConfig: @@ -92,17 +92,46 @@ class CausalSelfAttention(nn.Module): # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context if kv_cache is None: # Training: causal attention with optional sliding window - y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) + if flash_attn is not None: + y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) + else: + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) else: # Inference: use flash_attn_with_kvcache which handles cache management k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx) - y = flash_attn.flash_attn_with_kvcache( - q, k_cache, v_cache, - k=k, v=v, - cache_seqlens=kv_cache.cache_seqlens, - causal=True, - window_size=window_size, - ) + + if flash_attn is not None: + # Optimized Path (Linux/CUDA with FA3) + y = flash_attn.flash_attn_with_kvcache( + q, k_cache, v_cache, + k=k, v=v, + cache_seqlens=kv_cache.cache_seqlens, + causal=True, + window_size=window_size, + ) + else: + # Fallback Path (macOS/MPS or CPU) - Manual Cache Update + SDPA + positions = kv_cache.cache_seqlens + # Update cache manually for the batch + for b in range(B): + pos = positions[b].item() + k_cache[b, pos:pos+T] = k[b] + v_cache[b, pos:pos+T] = v[b] + + # Compute attention manually + y = torch.empty_like(q) + for b in range(B): + pos = positions[b].item() + + # Retrieve context & FORCE CAST to match query dtype (Fix for float vs bfloat16 crash) + k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype) + v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype) + q_curr = q[b:b+1] + + # Standard Attention (is_causal=False because context is explicit) + att_out = F.scaled_dot_product_attention(q_curr, k_curr, v_curr, is_causal=False) + y[b] = att_out[0] + # Advance position after last layer processes if self.layer_idx == kv_cache.n_layers - 1: kv_cache.advance(T) From 68e66be05c698dce4d6c48a60fc94ab8bf516466 Mon Sep 17 00:00:00 2001 From: hasan Date: Wed, 14 Jan 2026 15:23:55 +0100 Subject: [PATCH 2/4] fix: wrap FA3 import in try-except block to support both CUDA and MPS --- nanochat/gpt.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 38ba153..cd88e46 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -29,7 +29,13 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain. # Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal) -flash_attn = None +from kernels import get_kernel + +try: + flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface +except Exception: + # Kernel loading failed (e.g. on Mac/MPS or CPU), fallback to SDPA + flash_attn = None @dataclass class GPTConfig: From d7fccbab82620533360ff1cb4bffb48e95b617d8 Mon Sep 17 00:00:00 2001 From: hasan Date: Wed, 14 Jan 2026 21:42:20 +0100 Subject: [PATCH 3/4] fix: enforce (B, H, T, D) layout for SDPA fallback to support CPU strictness --- nanochat/gpt.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index cd88e46..a65e120 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -28,9 +28,7 @@ import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain. # Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal) - from kernels import get_kernel - try: flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface except Exception: @@ -101,7 +99,14 @@ class CausalSelfAttention(nn.Module): if flash_attn is not None: y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) else: - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + # Fallback Path (CPU/MPS): Needs Transpose to (B, H, T, D) for SDPA + y = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + is_causal=True + ) + y = y.transpose(1, 2) # Restore layout to (B, T, H, D) else: # Inference: use flash_attn_with_kvcache which handles cache management k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx) @@ -129,14 +134,15 @@ class CausalSelfAttention(nn.Module): for b in range(B): pos = positions[b].item() - # Retrieve context & FORCE CAST to match query dtype (Fix for float vs bfloat16 crash) - k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype) - v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype) - q_curr = q[b:b+1] + # Fetch history + cast to q.dtype (MPS fix) + Transpose (CPU fix) + k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype).transpose(1, 2) + v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype).transpose(1, 2) + q_curr = q[b:b+1].transpose(1, 2) # Standard Attention (is_causal=False because context is explicit) att_out = F.scaled_dot_product_attention(q_curr, k_curr, v_curr, is_causal=False) - y[b] = att_out[0] + # Transpose back to (B, T, H, D) and store + y[b] = att_out.transpose(1, 2)[0] # Advance position after last layer processes if self.layer_idx == kv_cache.n_layers - 1: @@ -440,4 +446,4 @@ class GPT(nn.Module): next_ids = torch.argmax(logits, dim=-1, keepdim=True) ids = torch.cat((ids, next_ids), dim=1) token = next_ids.item() - yield token + yield token \ No newline at end of file From 97364273e272fe7cdd70fc7b99b4900f28a9c6bf Mon Sep 17 00:00:00 2001 From: hasan Date: Wed, 14 Jan 2026 22:14:42 +0100 Subject: [PATCH 4/4] feat: restrict FA3 loading to Hopper+ GPUs (SM90+) to fix crashes on consumer hardware --- nanochat/gpt.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index a65e120..d214054 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -29,11 +29,17 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain. # Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal) from kernels import get_kernel + +flash_attn = None try: - flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface + # Flash Attention 3 uses NVIDIA Hopper-specific features like TMA (Tensor Memory Accelerator). + # These are only physically available on GPUs with Compute Capability >= 9.0 (e.g. H100). + # We explicitly check for this to prevent "No kernel image available" crashes on Ampere/Ada GPUs (RTX 30xx/40xx) etc. + if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9: + flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface except Exception: - # Kernel loading failed (e.g. on Mac/MPS or CPU), fallback to SDPA - flash_attn = None + # Fallback to PyTorch SDPA on non-Hopper NVIDIA GPUs, Mac (MPS), or CPU. + pass @dataclass class GPTConfig: