diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 67917b1b..c7e2d3ec 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -26,8 +26,8 @@ def _load_flash_attention_3(): if not torch.cuda.is_available(): return None try: - # FA3 kernels are currently compiled for Hopper (sm90) and Ampere (sm80/sm86) - # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled + # FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86) + # Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released from kernels import get_kernel, has_kernel supported = has_kernel(hf_kernel) if not supported: