From 433aacf77010263644a578e034188fc657be9395 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Mon, 2 Feb 2026 18:38:54 +0100 Subject: [PATCH] improve FA3 kernel loading --- nanochat/flash_attention.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 89ca42b..69b7144 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -18,22 +18,21 @@ import torch.nn.functional as F # ============================================================================= -# Detection: Try to load FA3 on Hopper+ GPUs +# Detection: Try to load FA3 on CUDA GPUs # ============================================================================= def _load_flash_attention_3(): - """Try to load Flash Attention 3 (requires Hopper GPU, sm90).""" + """Try to load Flash Attention 3.""" + hf_kernel = 'kernels-community/flash-attn3' if not torch.cuda.is_available(): return None try: - major, _ = torch.cuda.get_device_capability() - # FA3 kernels are compiled for Hopper (sm90) only - # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled - if major != 9: + from kernels import get_kernel, has_kernel + supported = has_kernel(hf_kernel) + if not supported: return None import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" - from kernels import get_kernel - return get_kernel('varunneal/flash-attention-3').flash_attn_interface + return get_kernel(hf_kernel).flash_attn_interface except Exception: return None