diff --git a/nanochat/gpt.py b/nanochat/gpt.py index cd88e46..a65e120 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -28,9 +28,7 @@ import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain. # Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal) - from kernels import get_kernel - try: flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface except Exception: @@ -101,7 +99,14 @@ class CausalSelfAttention(nn.Module): if flash_attn is not None: y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) else: - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + # Fallback Path (CPU/MPS): Needs Transpose to (B, H, T, D) for SDPA + y = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + is_causal=True + ) + y = y.transpose(1, 2) # Restore layout to (B, T, H, D) else: # Inference: use flash_attn_with_kvcache which handles cache management k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx) @@ -129,14 +134,15 @@ class CausalSelfAttention(nn.Module): for b in range(B): pos = positions[b].item() - # Retrieve context & FORCE CAST to match query dtype (Fix for float vs bfloat16 crash) - k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype) - v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype) - q_curr = q[b:b+1] + # Fetch history + cast to q.dtype (MPS fix) + Transpose (CPU fix) + k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype).transpose(1, 2) + v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype).transpose(1, 2) + q_curr = q[b:b+1].transpose(1, 2) # Standard Attention (is_causal=False because context is explicit) att_out = F.scaled_dot_product_attention(q_curr, k_curr, v_curr, is_causal=False) - y[b] = att_out[0] + # Transpose back to (B, T, H, D) and store + y[b] = att_out.transpose(1, 2)[0] # Advance position after last layer processes if self.layer_idx == kv_cache.n_layers - 1: @@ -440,4 +446,4 @@ class GPT(nn.Module): next_ids = torch.argmax(logits, dim=-1, keepdim=True) ids = torch.cat((ids, next_ids), dim=1) token = next_ids.item() - yield token + yield token \ No newline at end of file