keep varunneal kernel on H100, use community kernel for other supported cuda architectures

This commit is contained in:
svlandeg 2026-03-11 16:45:25 +01:00
parent a14c399576
commit 683b88c2b5

View File

@ -22,16 +22,15 @@ import torch.nn.functional as F
# =============================================================================
def _load_flash_attention_3():
"""Try to load Flash Attention 3."""
hf_kernel = 'kernels-community/flash-attn3'
if not torch.cuda.is_available():
return None
try:
cap = torch.cuda.get_device_capability()
# FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86)
# Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released
from kernels import get_kernel, has_kernel
supported = has_kernel(hf_kernel)
if not supported:
return None
# varunneal kernel obtains better results for H100
hf_kernel = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
from kernels import get_kernel
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
return get_kernel(hf_kernel).flash_attn_interface