mirror of https://github.com/karpathy/nanochat.git

improve FA3 kernel loading

commit 433aacf770, parent 72b9064f9d
@@ -18,22 +18,21 @@ import torch.nn.functional as F
 # =============================================================================
-# Detection: Try to load FA3 on Hopper+ GPUs
+# Detection: Try to load FA3 on CUDA GPUs
 # =============================================================================
 def _load_flash_attention_3():
-    """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
+    """Try to load Flash Attention 3."""
+    hf_kernel = 'kernels-community/flash-attn3'
     if not torch.cuda.is_available():
         return None
     try:
-        major, _ = torch.cuda.get_device_capability()
-        # FA3 kernels are compiled for Hopper (sm90) only
-        # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
-        if major != 9:
+        from kernels import get_kernel, has_kernel
+        supported = has_kernel(hf_kernel)
+        if not supported:
             return None
         import os
         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
-        from kernels import get_kernel
-        return get_kernel('varunneal/flash-attention-3').flash_attn_interface
+        return get_kernel(hf_kernel).flash_attn_interface
     except Exception:
         return None
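For context, a minimal sketch of how the loader's result might be consumed, assuming a fallback to PyTorch SDPA when no FA3 kernel is available. The names flash_attn_3 and attend are illustrative and not taken from this commit, and the exact return value of FA3's flash_attn_func can vary between kernel builds, so the sketch checks before unpacking.

# Illustrative usage sketch (not part of this diff); names are hypothetical.
import torch
import torch.nn.functional as F

flash_attn_3 = _load_flash_attention_3()  # None on CPU-only machines or when no kernel build matches

def attend(q, k, v):
    # q, k, v: (batch, n_heads, seq_len, head_dim); causal self-attention
    if flash_attn_3 is not None:
        # FA3 expects (batch, seq_len, n_heads, head_dim), so transpose around the call.
        out = flash_attn_3.flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True
        )
        if isinstance(out, tuple):
            # some builds also return the softmax log-sum-exp
            out = out[0]
        return out.transpose(1, 2)
    # fallback: plain PyTorch scaled dot-product attention
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)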