diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 81ccb0c..5f00bc2 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -29,7 +29,21 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain. # Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal) from kernels import get_kernel -flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface + +flash_attn = None +try: + # Flash Attention 3 uses NVIDIA Hopper-specific features like TMA (Tensor Memory Accelerator). + # These are only physically available on GPUs with Compute Capability >= 9.0 (e.g. H100). + # We explicitly check for this to prevent "No kernel image available" crashes on Ampere/Ada GPUs (RTX 30xx/40xx) etc. + if torch.cuda.is_available(): + if torch.cuda.get_device_capability()[0] >= 9: + flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface + else: + # If the kernel image is not available, try installing the wheel manually from https://windreamer.github.io/flash-attention3-wheels/ + import flash_attn_interface as flash_attn +except Exception: + # Fallback to PyTorch SDPA on non-Hopper NVIDIA GPUs, Mac (MPS), or CPU. + pass @dataclass class GPTConfig: @@ -92,17 +106,54 @@ class CausalSelfAttention(nn.Module): # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context if kv_cache is None: # Training: causal attention with optional sliding window - y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) + if flash_attn is not None: + y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) + else: + # Fallback Path (CPU/MPS): Needs Transpose to (B, H, T, D) for SDPA + y = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + is_causal=True + ) + y = y.transpose(1, 2) # Restore layout to (B, T, H, D) else: # Inference: use flash_attn_with_kvcache which handles cache management k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx) - y = flash_attn.flash_attn_with_kvcache( - q, k_cache, v_cache, - k=k, v=v, - cache_seqlens=kv_cache.cache_seqlens, - causal=True, - window_size=window_size, - ) + + if flash_attn is not None: + # Optimized Path (Linux/CUDA with FA3) + y = flash_attn.flash_attn_with_kvcache( + q, k_cache, v_cache, + k=k, v=v, + cache_seqlens=kv_cache.cache_seqlens, + causal=True, + window_size=window_size, + ) + else: + # Fallback Path (macOS/MPS or CPU) - Manual Cache Update + SDPA + positions = kv_cache.cache_seqlens + # Update cache manually for the batch + for b in range(B): + pos = positions[b].item() + k_cache[b, pos:pos+T] = k[b] + v_cache[b, pos:pos+T] = v[b] + + # Compute attention manually + y = torch.empty_like(q) + for b in range(B): + pos = positions[b].item() + + # Fetch history + cast to q.dtype (MPS fix) + Transpose (CPU fix) + k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype).transpose(1, 2) + v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype).transpose(1, 2) + q_curr = q[b:b+1].transpose(1, 2) + + # Standard Attention (is_causal=False because context is explicit) + att_out = F.scaled_dot_product_attention(q_curr, k_curr, v_curr, is_causal=False) + # Transpose back to (B, T, H, D) and store + y[b] = att_out.transpose(1, 2)[0] + # Advance position after last layer processes if self.layer_idx == kv_cache.n_layers - 1: kv_cache.advance(T) @@ -405,4 +456,4 @@ class GPT(nn.Module): next_ids = torch.argmax(logits, dim=-1, keepdim=True) ids = torch.cat((ids, next_ids), dim=1) token = next_ids.item() - yield token + yield token \ No newline at end of file