Fix for unsupported CUDA

This commit is contained in:
svlandeg 2026-03-11 16:52:50 +01:00
parent ff833f8137
commit d343c939c0

View File

@ -28,12 +28,20 @@ def _load_flash_attention_3():
major, _ = torch.cuda.get_device_capability()
# FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86)
# Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released
# varunneal kernel obtains better results for H100/Hopper
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from kernels import get_kernel
hf_kernel = "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3"
return get_kernel(hf_kernel).flash_attn_interface
from kernels import get_kernel, has_kernel
# The varunneal kernel obtains better results for H100/Hopper
if major == 9:
hf_kernel = "varunneal/flash-attention-3"
return get_kernel(hf_kernel).flash_attn_interface
else:
hf_kernel = "kernels-community/flash-attn3"
if has_kernel(hf_kernel):
return get_kernel(hf_kernel).flash_attn_interface
else:
return None
except Exception:
return None