Add Blackwell (SM100) GPU support via SDPA fallback (#475)

Franci Penov 2026-01-31 19:42:58 -08:00 committed by GitHub
parent 0307997f9b
commit dc291c627f

@@ -2,7 +2,7 @@
 Unified Flash Attention interface with automatic FA3/SDPA switching.
 
 Exports `flash_attn` module that matches the FA3 API exactly, but falls back
-to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
+to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
 
 Usage (drop-in replacement for FA3):
     from nanochat.flash_attention import flash_attn
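
For orientation, the module's fallback path is a thin wrapper over PyTorch's SDPA. A minimal sketch of that shape, assuming FA3's (batch, seqlen, heads, head_dim) tensor layout; `sdpa_flash_attn_func` is an illustrative name, not the file's actual code:

    # Sketch only: an SDPA-backed stand-in for an FA3-style attention call.
    # Assumes FA3's (batch, seqlen, heads, head_dim) layout; the function
    # name is hypothetical, not nanochat's API.
    import torch
    import torch.nn.functional as F

    def sdpa_flash_attn_func(q, k, v, causal=False):
        # SDPA expects (batch, heads, seqlen, head_dim), so transpose in and out
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
        return out.transpose(1, 2)
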
@@ -21,12 +21,14 @@ import torch.nn.functional as F
 # Detection: Try to load FA3 on Hopper+ GPUs
 # =============================================================================
 def _load_flash_attention_3():
-    """Try to load Flash Attention 3 (requires Hopper+ GPU)."""
+    """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
     if not torch.cuda.is_available():
         return None
     try:
         major, _ = torch.cuda.get_device_capability()
-        if major < 9:  # Hopper is sm90
+        # FA3 kernels are compiled for Hopper (sm90) only
+        # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
+        if major != 9:
             return None
         import os
         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
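
Why `major != 9` rather than `major < 9`: `torch.cuda.get_device_capability()` returns a (major, minor) compute-capability tuple, and Blackwell's major version is above Hopper's, so the old `<` test let Blackwell try, and fail, to load the sm90-only FA3 kernels. A hedged illustration of the capability values involved:

    # Illustrative (major, minor) compute capabilities for the GPUs named above:
    #   Ada       -> (8, 9)   sm89:  major < 9, fell back to SDPA even before this fix
    #   Hopper    -> (9, 0)   sm90:  the only target of the prebuilt FA3 kernels
    #   Blackwell -> (10, 0)  sm100: major > 9, so the old `major < 9` gate kept FA3 on
    import torch

    major, _ = torch.cuda.get_device_capability()
    use_fa3 = (major == 9)  # mirrors the committed `if major != 9: return None` gate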