feat: restrict FA3 loading to Hopper+ GPUs (SM90+) to fix crashes on consumer hardware

This commit is contained in:
hasan 2026-01-14 22:14:42 +01:00
parent d7fccbab82
commit 97364273e2

View File

@ -29,11 +29,17 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
flash_attn = None
try:
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
# Flash Attention 3 uses NVIDIA Hopper-specific features like TMA (Tensor Memory Accelerator).
# These are only physically available on GPUs with Compute Capability >= 9.0 (e.g. H100).
# We explicitly check for this to prevent "No kernel image available" crashes on Ampere/Ada GPUs (RTX 30xx/40xx) etc.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
except Exception:
# Kernel loading failed (e.g. on Mac/MPS or CPU), fallback to SDPA
flash_attn = None
# Fallback to PyTorch SDPA on non-Hopper NVIDIA GPUs, Mac (MPS), or CPU.
pass
@dataclass
class GPTConfig: