diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index a65e120..d214054 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -29,11 +29,17 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
 # Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
 # Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
 from kernels import get_kernel
+
+flash_attn = None
 try:
-    flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
+    # Flash Attention 3 uses NVIDIA Hopper-specific features like TMA (Tensor Memory Accelerator).
+    # These are only physically available on GPUs with Compute Capability >= 9.0 (e.g. H100).
+    # We explicitly check for this to prevent "No kernel image available" crashes on Ampere/Ada GPUs (RTX 30xx/40xx) etc.
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
+        flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
 except Exception:
-    # Kernel loading failed (e.g. on Mac/MPS or CPU), fallback to SDPA
-    flash_attn = None
+    # Fallback to PyTorch SDPA on non-Hopper NVIDIA GPUs, Mac (MPS), or CPU.
+    pass
 
 @dataclass
 class GPTConfig:
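
For context, here is a minimal sketch (not part of the patch) of how an attention call site can honor the new `flash_attn = None` fallback: use the FA3 kernel when it loaded, otherwise fall back to PyTorch's SDPA. The helper name `attend`, the tensor layout, and the FA3 call details are illustrative assumptions, not the file's actual code.

```python
import torch
import torch.nn.functional as F

def attend(q, k, v):
    # q, k, v: (batch, heads, seq_len, head_dim) -- layout assumed for illustration.
    if flash_attn is not None:
        # Hopper path: FA3 kernels take (batch, seq_len, heads, head_dim) inputs.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        out = flash_attn.flash_attn_func(q, k, v, causal=True)
        out = out[0] if isinstance(out, tuple) else out  # some FA3 builds also return the LSE
        return out.transpose(1, 2)
    # Portable path: PyTorch's fused scaled dot-product attention runs on any backend.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Because `flash_attn` now defaults to `None` and is only set inside the guarded `try` block, every non-Hopper device (older NVIDIA GPUs, MPS, CPU) takes the SDPA branch instead of crashing at kernel launch.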