Mirror of https://github.com/karpathy/nanochat.git, synced 2026-02-02 16:49:50 +00:00
Add Blackwell (SM100) GPU support via SDPA fallback (#475)
This commit is contained in:
parent 0307997f9b · commit dc291c627f
@@ -2,7 +2,7 @@
 Unified Flash Attention interface with automatic FA3/SDPA switching.
 
 Exports `flash_attn` module that matches the FA3 API exactly, but falls back
-to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
+to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
 
 Usage (drop-in replacement for FA3):
     from nanochat.flash_attention import flash_attn
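For context on the docstring above: the fallback works by presenting the FA3 calling convention on top of PyTorch SDPA. Below is a minimal sketch of that idea, assuming FA3's usual flash_attn_func(q, k, v, causal=...) style call with (batch, seqlen, nheads, headdim) tensors; the function name is illustrative, not the repo's exact code.

    import torch
    import torch.nn.functional as F

    def sdpa_flash_attn_func(q, k, v, causal=False, softmax_scale=None):
        # Hypothetical stand-in for FA3's flash_attn_func. Assumes FA3's
        # (batch, seqlen, nheads, headdim) layout; SDPA expects
        # (batch, nheads, seqlen, headdim), so transpose in and out.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(
            q, k, v, is_causal=causal,
            scale=softmax_scale,  # None -> default 1/sqrt(headdim)
        )
        return out.transpose(1, 2)

    # Usage: same shapes in and out as the FA3 call it replaces.
    q = k = v = torch.randn(1, 128, 8, 64)  # (B, T, H, D)
    out = sdpa_flash_attn_func(q, k, v, causal=True)
    assert out.shape == q.shape

Keeping the transposes inside the wrapper is what makes the module a drop-in replacement: callers never see the layout difference between FA3 and SDPA.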
@@ -21,12 +21,14 @@ import torch.nn.functional as F
 # Detection: Try to load FA3 on Hopper+ GPUs
 # =============================================================================
 def _load_flash_attention_3():
-    """Try to load Flash Attention 3 (requires Hopper+ GPU)."""
+    """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
     if not torch.cuda.is_available():
         return None
     try:
         major, _ = torch.cuda.get_device_capability()
-        if major < 9:  # Hopper is sm90
+        # FA3 kernels are compiled for Hopper (sm90) only
+        # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
+        if major != 9:
             return None
         import os
         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
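Why `!= 9` rather than `< 9`: CUDA compute-capability majors keep increasing, so Blackwell (sm100, major 10) passed the old `< 9` check and attempted FA3's sm90-only kernels. A minimal sketch of the corrected gate follows, with a hypothetical helper name for illustration.

    import torch

    def _should_use_fa3() -> bool:
        # Hypothetical helper mirroring the gate in the diff above.
        # get_device_capability() returns (major, minor): Ampere (8, 0),
        # Ada (8, 9), Hopper (9, 0), Blackwell (10, 0). FA3 kernels are
        # built for sm90 only, so the old `major < 9` test let Blackwell
        # (major 10) fall through and try FA3; `major != 9` routes
        # everything except Hopper to the SDPA fallback instead.
        if not torch.cuda.is_available():
            return False
        major, _ = torch.cuda.get_device_capability()
        return major == 9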