diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 2b1f41a..6b5dae0 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -22,16 +22,15 @@ import torch.nn.functional as F # ============================================================================= def _load_flash_attention_3(): """Try to load Flash Attention 3.""" - hf_kernel = 'kernels-community/flash-attn3' if not torch.cuda.is_available(): return None try: + cap = torch.cuda.get_device_capability() # FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86) # Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released - from kernels import get_kernel, has_kernel - supported = has_kernel(hf_kernel) - if not supported: - return None + # The varunneal kernel achieves better results on H100 (sm90) + hf_kernel = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + from kernels import get_kernel import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" return get_kernel(hf_kernel).flash_attn_interface