diff --git a/nanochat/gpt.py b/nanochat/gpt.py index d214054..5f00bc2 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -35,8 +35,12 @@ try: # Flash Attention 3 uses NVIDIA Hopper-specific features like TMA (Tensor Memory Accelerator). # These are only physically available on GPUs with Compute Capability >= 9.0 (e.g. H100). # We explicitly check for this to prevent "No kernel image available" crashes on Ampere/Ada GPUs (RTX 30xx/40xx) etc. - if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9: - flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface + if torch.cuda.is_available(): + if torch.cuda.get_device_capability()[0] >= 9: + flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface + else: + # If the kernel image is not available, try installing the wheel manually from https://windreamer.github.io/flash-attention3-wheels/ + import flash_attn_interface as flash_attn except Exception: # Fallback to PyTorch SDPA on non-Hopper NVIDIA GPUs, Mac (MPS), or CPU. pass