fix: wrap FA3 import in try-except block to support both CUDA and MPS

This commit is contained in:
hasan 2026-01-14 15:23:55 +01:00
parent c9c01ffe04
commit 68e66be05c

View File

@ -29,7 +29,13 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
# Try to load the FlashAttention-3 kernel; fall back to SDPA when unavailable.
# The `from kernels import get_kernel` line lives INSIDE the try: on Mac/MPS or
# CPU-only machines the `kernels` package (or its CUDA deps) may be missing
# entirely, and an uncaught ImportError here would defeat the whole fallback.
flash_attn = None
try:
    from kernels import get_kernel
    flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
except Exception:
    # Kernel loading failed (e.g. on Mac/MPS or CPU), fallback to SDPA
    flash_attn = None
@dataclass
class GPTConfig: