From d343c939c0550b1f6780051790adcd5f4d040f22 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Wed, 11 Mar 2026 16:52:50 +0100
Subject: [PATCH] fix for unsupported cuda

---
 nanochat/flash_attention.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index 3950da5..7412b91 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -28,12 +28,20 @@ def _load_flash_attention_3():
         major, _ = torch.cuda.get_device_capability()
         # FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86)
         # Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released
-        # varunneal kernel obtains better results for H100/Hopper
         import os
         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
-        from kernels import get_kernel
-        hf_kernel = "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3"
-        return get_kernel(hf_kernel).flash_attn_interface
+        from kernels import get_kernel, has_kernel
+        # The varunneal kernel obtains better results for H100/Hopper
+        if major == 9:
+            hf_kernel = "varunneal/flash-attention-3"
+            return get_kernel(hf_kernel).flash_attn_interface
+        else:
+            hf_kernel = "kernels-community/flash-attn3"
+            if has_kernel(hf_kernel):
+                return get_kernel(hf_kernel).flash_attn_interface
+            else:
+                return None
+
     except Exception:
         return None
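
Note (reviewer sketch, not part of the patch): with this change the loader now
returns None on unsupported devices instead of raising, so call sites can take
the SDPA fallback the comments mention. Below is a rough sketch of what such a
call site could look like. The attention() wrapper, the tensor layouts, and the
flash_attn_func entry point are assumptions based on the usual flash-attn
interface, not anything this patch defines.

    import torch
    import torch.nn.functional as F

    from nanochat.flash_attention import _load_flash_attention_3

    fa3 = _load_flash_attention_3()  # None when no compatible FA3 kernel exists

    def attention(q, k, v):
        # q, k, v: [batch, seqlen, heads, head_dim] (flash-attn layout).
        if fa3 is not None:
            out = fa3.flash_attn_func(q, k, v, causal=True)
            # Some FA3 builds return (out, softmax_lse); keep only the output.
            if isinstance(out, tuple):
                out = out[0]
            return out
        # SDPA fallback (e.g. Blackwell): PyTorch expects
        # [batch, heads, seqlen, head_dim], so transpose in and back out.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        return F.scaled_dot_product_attention(q, k, v, is_causal=True).transpose(1, 2)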