diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index c79d124..67917b1 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -26,6 +26,8 @@ def _load_flash_attention_3():
     if not torch.cuda.is_available():
         return None
     try:
+        # FA3 kernels are currently compiled for Hopper (sm90) and Ampere (sm80/sm86)
+        # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
         from kernels import get_kernel, has_kernel
         supported = has_kernel(hf_kernel)
         if not supported:
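
For context on the fallback path the new comment refers to, here is a minimal sketch of an FA3-or-SDPA dispatch for GPUs where the FA3 kernel is unavailable. The names `_fa3`, `attention`, and `flash_attn_func` are illustrative assumptions, not nanochat's actual API; only the general shape of the fallback is being shown.

    # Minimal sketch (not part of the diff above) of the FA3-or-SDPA dispatch
    # described by the new comment. `_fa3`, `attention`, and `flash_attn_func`
    # are assumed names for illustration; nanochat's actual wiring may differ.
    import torch
    import torch.nn.functional as F

    _fa3 = None  # set by _load_flash_attention_3() when a compatible FA3 kernel loads

    def attention(q, k, v, causal=True):
        # q, k, v: (batch, seqlen, n_heads, head_dim)
        if _fa3 is not None:
            # Assumed FA3 entry point; exact signature depends on the kernel build.
            return _fa3.flash_attn_func(q, k, v, causal=causal)
        # SDPA fallback: torch SDPA expects (batch, n_heads, seqlen, head_dim),
        # so transpose around the call and transpose back afterwards.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
        return out.transpose(1, 2)

On Ada (sm89) and Blackwell (sm100) the `_fa3` handle would stay `None` until recompiled FA3 kernels ship, so every call takes the SDPA branch.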