Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-20 10:23:42 +00:00)
feat: attempt fa3 load on sm < 9.0 (ampere/ada)
Allow FA3 on non-Hopper GPUs if manually installed.

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent 38e4e0dd7b
commit 3e5fccdfa4
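The title's "sm < 9.0" refers to the CUDA compute capability: Hopper GPUs (H100) report major version 9, while Ampere/Ada GPUs (RTX 30xx/40xx) report 8. A minimal, purely illustrative check (not part of nanochat) to see which branch of the new loader a given GPU would take:

import torch

# Prints the SM (compute capability) of the current CUDA device,
# e.g. sm_90 on H100 (Hopper), sm_86 on RTX 30xx (Ampere), sm_89 on RTX 40xx (Ada).
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"sm_{major}{minor}")
else:
    print("No CUDA device available")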
@@ -35,8 +35,12 @@ try:
     # Flash Attention 3 uses NVIDIA Hopper-specific features like TMA (Tensor Memory Accelerator).
     # These are only physically available on GPUs with Compute Capability >= 9.0 (e.g. H100).
     # We explicitly check for this to prevent "No kernel image available" crashes on Ampere/Ada GPUs (RTX 30xx/40xx) etc.
-    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
-        flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
+    if torch.cuda.is_available():
+        if torch.cuda.get_device_capability()[0] >= 9:
+            flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
+        else:
+            # If the kernel image is not available, try installing the wheel manually from https://windreamer.github.io/flash-attention3-wheels/
+            import flash_attn_interface as flash_attn
 except Exception:
     # Fallback to PyTorch SDPA on non-Hopper NVIDIA GPUs, Mac (MPS), or CPU.
     pass
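Putting the hunk in context, here is a minimal sketch of how the loader reads after this change. The enclosing try/except and the get_kernel import (from the Hugging Face kernels package) sit outside this hunk, so their exact form here is an assumption:

import torch

flash_attn = None  # assumption: downstream code uses PyTorch SDPA when this stays None
try:
    from kernels import get_kernel  # Hugging Face `kernels` package (assumed import, outside the hunk)
    if torch.cuda.is_available():
        if torch.cuda.get_device_capability()[0] >= 9:
            # Hopper (sm >= 9.0): fetch the prebuilt FA3 kernel from the Hub.
            flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
        else:
            # Ampere/Ada (sm < 9.0): FA3 only loads if a wheel was installed manually,
            # e.g. from https://windreamer.github.io/flash-attention3-wheels/
            import flash_attn_interface as flash_attn
except Exception:
    # Fallback to PyTorch SDPA on non-Hopper NVIDIA GPUs without the wheel, Mac (MPS), or CPU.
    pass

Note that the manual import still sits inside the outer try, so a missing flash_attn_interface wheel on an Ampere/Ada machine simply drops through to the SDPA fallback rather than crashing at import time.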