Blackwell support

Franci Penov 2026-01-30 14:52:37 -08:00
parent 4d8dbaf6e0
commit 2e45b7800a


@@ -2,7 +2,7 @@
 Unified Flash Attention interface with automatic FA3/SDPA switching.
 
 Exports `flash_attn` module that matches the FA3 API exactly, but falls back
-to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
+to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
 
 Usage (drop-in replacement for FA3):
     from nanochat.flash_attention import flash_attn
@@ -21,12 +21,14 @@ import torch.nn.functional as F
 # Detection: Try to load FA3 on Hopper+ GPUs
 # =============================================================================
 def _load_flash_attention_3():
-    """Try to load Flash Attention 3 (requires Hopper+ GPU)."""
+    """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
     if not torch.cuda.is_available():
         return None
     try:
-        major, _ = torch.cuda.get_device_capability()
-        if major < 9:  # Hopper is sm90
+        major, minor = torch.cuda.get_device_capability()
+        # FA3 kernels are compiled for Hopper (sm90) only
+        # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
+        if major != 9:
             return None
         import os
         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
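For illustration, the routing this change introduces can be sketched in isolation. The snippet below is not code from the commit; the pick_attention_backend name is invented here, and it simply mirrors the check in the diff: FA3 only on devices reporting compute capability major 9 (Hopper), SDPA everywhere else, including Blackwell, Ada, MPS, and CPU.

import torch

def pick_attention_backend():
    """Illustrative sketch (not repo code): mirror the commit's capability check."""
    if not torch.cuda.is_available():
        return "sdpa"  # covers CPU and Apple MPS
    major, minor = torch.cuda.get_device_capability()
    # Hopper reports major 9 (sm90); Ada reports (8, 9); Blackwell reports major 10.
    # Under this change only major == 9 takes the FA3 path.
    return "fa3" if major == 9 else "sdpa"

if __name__ == "__main__":
    print(pick_attention_backend())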