Mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-26 21:34:25 +00:00
feat: restrict FA3 loading to Hopper+ GPUs (SM90+) to fix crashes on consumer hardware
This commit is contained in:
parent d7fccbab82
commit 97364273e2
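For context, the gate this commit introduces keys off the CUDA compute capability that PyTorch reports for the active device. Below is a minimal standalone sketch (not part of the commit) of that check, using the standard capability values: H100 reports 9.0 (Hopper), A100 8.0 and RTX 30xx 8.6 (Ampere), RTX 40xx 8.9 (Ada).

import torch

# Standalone sketch of the gate used in the diff below: FA3 only loads when
# the device's compute capability major version is >= 9 (Hopper, e.g. H100).
# Ampere (A100: 8.0, RTX 30xx: 8.6) and Ada (RTX 40xx: 8.9) fail the check
# and keep the SDPA fallback instead of crashing with "no kernel image".
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"sm_{major}{minor}", "-> FA3" if major >= 9 else "-> SDPA fallback")
else:
    print("no CUDA device -> SDPA fallback")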
@@ -29,11 +29,17 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
 # Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
 # Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
 from kernels import get_kernel
+
+flash_attn = None
 try:
-    flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
+    # Flash Attention 3 uses NVIDIA Hopper-specific features like TMA (Tensor Memory Accelerator).
+    # These are only physically available on GPUs with Compute Capability >= 9.0 (e.g. H100).
+    # We explicitly check for this to prevent "No kernel image available" crashes on Ampere/Ada GPUs (RTX 30xx/40xx) etc.
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
+        flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
 except Exception:
-    # Kernel loading failed (e.g. on Mac/MPS or CPU), fallback to SDPA
-    flash_attn = None
+    # Fallback to PyTorch SDPA on non-Hopper NVIDIA GPUs, Mac (MPS), or CPU.
+    pass

 @dataclass
 class GPTConfig:
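Downstream of this hunk, the module-level flash_attn name is either the loaded FA3 interface or None. A minimal sketch of how such a flag is typically consumed follows; this is not nanochat's actual attention code, and the flash_attn_func name, its (q, k, v, causal=...) signature, and the tensor layouts are assumptions about the varunneal/flash-attention-3 interface.

import torch.nn.functional as F

def causal_attention(q, k, v):
    # Hypothetical helper. Assumes q, k, v are (batch, seqlen, heads, head_dim),
    # the layout FA3 kernels expect; SDPA wants (batch, heads, seqlen, head_dim),
    # hence the transposes on the fallback path.
    if flash_attn is not None:
        # FA3 path; call and signature assumed, some FA3 builds return (out, lse).
        out = flash_attn.flash_attn_func(q, k, v, causal=True)
        return out[0] if isinstance(out, tuple) else out
    # SDPA fallback used on non-Hopper GPUs, Apple MPS, or CPU.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    )
    return out.transpose(1, 2)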