From dc291c627f69fb2fcc2298582946d5fa2d0384cd Mon Sep 17 00:00:00 2001
From: Franci Penov
Date: Sat, 31 Jan 2026 19:42:58 -0800
Subject: [PATCH] Add Blackwell (SM100) GPU support via SDPA fallback (#475)

---
 nanochat/flash_attention.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index 15411de..89ca42b 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -2,7 +2,7 @@
 Unified Flash Attention interface with automatic FA3/SDPA switching.
 
 Exports `flash_attn` module that matches the FA3 API exactly, but falls back
-to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
+to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
 
 Usage (drop-in replacement for FA3):
     from nanochat.flash_attention import flash_attn
@@ -21,12 +21,14 @@ import torch.nn.functional as F
 # Detection: Try to load FA3 on Hopper+ GPUs
 # =============================================================================
 def _load_flash_attention_3():
-    """Try to load Flash Attention 3 (requires Hopper+ GPU)."""
+    """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
     if not torch.cuda.is_available():
         return None
     try:
         major, _ = torch.cuda.get_device_capability()
-        if major < 9:  # Hopper is sm90
+        # FA3 kernels are compiled for Hopper (sm90) only
+        # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
+        if major != 9:
             return None
         import os
         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
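
For reference, a minimal sketch of the dispatch rule the new `major != 9` check encodes. It is not part of the patch; the helper names pick_attention_backend and _sdpa_fallback are hypothetical, and PyTorch SDPA stands in wherever the FA3 kernels are unavailable:

    import torch
    import torch.nn.functional as F

    def pick_attention_backend() -> str:
        # Mirrors the patched check: only Hopper (sm90) loads the FA3 kernels;
        # Ada (sm89), Blackwell (sm100), MPS, and CPU take the SDPA path.
        if torch.cuda.is_available():
            major, _ = torch.cuda.get_device_capability()
            if major == 9:
                return "fa3"
        return "sdpa"

    def _sdpa_fallback(q, k, v, causal=True):
        # Assumes (batch, n_heads, seq_len, head_dim) tensors, the layout
        # torch.nn.functional.scaled_dot_product_attention expects.
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

On a Blackwell card pick_attention_backend() returns "sdpa", which is the behavior this patch is aiming for until FA3 kernels are rebuilt for sm100.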