improve FA3 kernel loading

This commit is contained in:
svlandeg 2026-02-02 18:38:54 +01:00
parent 72b9064f9d
commit 433aacf770

View File

@ -18,22 +18,21 @@ import torch.nn.functional as F
# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
# Detection: Try to load FA3 on CUDA GPUs
# =============================================================================
def _load_flash_attention_3():
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
"""Try to load Flash Attention 3."""
hf_kernel = 'kernels-community/flash-attn3'
if not torch.cuda.is_available():
return None
try:
major, _ = torch.cuda.get_device_capability()
# FA3 kernels are compiled for Hopper (sm90) only
# Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
if major != 9:
from kernels import get_kernel, has_kernel
supported = has_kernel(hf_kernel)
if not supported:
return None
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from kernels import get_kernel
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
return get_kernel(hf_kernel).flash_attn_interface
except Exception:
return None