feat: attempt fa3 load on sm < 9.0 (ampere/ada)

allow fa3 on non-hopper gpus if manually installed

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
hasso 2026-01-16 11:18:12 +01:00 committed by GitHub
parent 38e4e0dd7b
commit 3e5fccdfa4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -35,8 +35,12 @@ try:
# Flash Attention 3 uses NVIDIA Hopper-specific features like TMA (Tensor Memory Accelerator).
# These are only physically available on GPUs with Compute Capability >= 9.0 (e.g. H100).
# We explicitly check for this to prevent "No kernel image available" crashes on Ampere/Ada GPUs (RTX 30xx/40xx) etc.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
if torch.cuda.is_available():
if torch.cuda.get_device_capability()[0] >= 9:
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
else:
# If the kernel image is not available, try installing the wheel manually from https://windreamer.github.io/flash-attention3-wheels/
import flash_attn_interface as flash_attn
except Exception:
# Fallback to PyTorch SDPA on non-Hopper NVIDIA GPUs, Mac (MPS), or CPU.
pass