diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index 89ca42b..69b7144 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -18,22 +18,21 @@ import torch.nn.functional as F
 
 # =============================================================================
-# Detection: Try to load FA3 on Hopper+ GPUs
+# Detection: Try to load FA3 on CUDA GPUs
 # =============================================================================
 
 def _load_flash_attention_3():
-    """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
+    """Try to load Flash Attention 3."""
+    hf_kernel = 'kernels-community/flash-attn3'
     if not torch.cuda.is_available():
         return None
     try:
-        major, _ = torch.cuda.get_device_capability()
-        # FA3 kernels are compiled for Hopper (sm90) only
-        # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
-        if major != 9:
+        from kernels import get_kernel, has_kernel
+        supported = has_kernel(hf_kernel)
+        if not supported:
             return None
         import os
         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
-        from kernels import get_kernel
-        return get_kernel('varunneal/flash-attention-3').flash_attn_interface
+        return get_kernel(hf_kernel).flash_attn_interface
     except Exception:
         return None
 
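
A minimal usage sketch (illustrative, not part of the patch): one way the loader's result could be wired to a PyTorch SDPA fallback. The attention() wrapper, the tensor layouts, and the tuple handling around flash_attn_func are assumptions here, not the module's actual code.

import torch
import torch.nn.functional as F

_flash_attn = _load_flash_attention_3()  # FA3 interface namespace, or None if unavailable

def attention(q, k, v, causal=True):
    # q, k, v: (batch, seqlen, num_heads, head_dim), the layout FA3 expects.
    if _flash_attn is not None:
        out = _flash_attn.flash_attn_func(q, k, v, causal=causal)
        # Assumption: some FA3 builds return (out, softmax_lse); keep only the output.
        if isinstance(out, tuple):
            out = out[0]
        return out
    # Fallback: PyTorch SDPA expects (batch, num_heads, seqlen, head_dim).
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out.transpose(1, 2)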