From ff833f8137426629305b0e24fe7c90cd1bc034ea Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 11 Mar 2026 16:49:09 +0100 Subject: [PATCH] small fix --- nanochat/flash_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 098afc3..3950da5 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -29,10 +29,10 @@ def _load_flash_attention_3(): # FA3 kernels are currently compiled for Hopper (sm90), Ada (sm89) and Ampere (sm80/sm86) # Blackwell (sm100) needs SDPA fallback until FA3 is recompiled or FA4 is released # varunneal kernel obtains better results for H100/Hopper - hf_kernel = "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3" - from kernels import get_kernel import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" + from kernels import get_kernel + hf_kernel = "varunneal/flash-attention-3" if major == 9 else "kernels-community/flash-attn3" return get_kernel(hf_kernel).flash_attn_interface except Exception: return None