mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
Fix for unsupported CUDA
This commit is contained in:
parent
ff833f8137
commit
d343c939c0
|
|
@ -28,12 +28,20 @@ def _load_flash_attention_3():
|
|||
major, _ = torch.cuda.get_device_capability()
|
||||
# FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86)
|
||||
# Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released
|
||||
# varunneal kernel obtains better results for H100/Hopper
|
||||
import os
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
from kernels import get_kernel
|
||||
hf_kernel = "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3"
|
||||
return get_kernel(hf_kernel).flash_attn_interface
|
||||
from kernels import get_kernel, has_kernel
|
||||
# The varunneal kernel obtains better results for H100/Hopper
|
||||
if major == 9:
|
||||
hf_kernel = "varunneal/flash-attention-3"
|
||||
return get_kernel(hf_kernel).flash_attn_interface
|
||||
else:
|
||||
hf_kernel = "kernels-community/flash-attn3"
|
||||
if has_kernel(hf_kernel):
|
||||
return get_kernel(hf_kernel).flash_attn_interface
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user