diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index 6b5dae0..098afc3 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -25,11 +25,11 @@ def _load_flash_attention_3():
     if not torch.cuda.is_available():
         return None
     try:
-        cap = torch.cuda.get_device_capability()
+        major, _ = torch.cuda.get_device_capability()
         # FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86)
         # Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released
-        # varunneal kernel obtains better results for H100
-        hf_kernel = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
+        # varunneal kernel obtains better results for H100/Hopper
+        hf_kernel = "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3"
         from kernels import get_kernel
         import os
         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
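
For context, a minimal standalone sketch of the selection logic this change introduces (the helper name `pick_fa3_repo` is hypothetical, not part of the patch): `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, so matching on the major version alone selects the varunneal build for any Hopper-class device rather than only the exact `(9, 0)` capability the old code compared against.

```python
import torch

def pick_fa3_repo():
    """Hypothetical helper mirroring the patched kernel-selection logic."""
    if not torch.cuda.is_available():
        return None
    # get_device_capability() returns (major, minor), e.g. (9, 0) on H100.
    major, _ = torch.cuda.get_device_capability()
    # Any sm9x (Hopper) device -> varunneal FA3 build; otherwise the
    # kernels-community build (Ada/Ampere), per the comments in the patch.
    return "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3"
```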