Fall back to flex_attention when FA3 is not available

This commit is contained in:
Your Name 2026-03-04 12:21:13 -08:00
parent 29b76c5695
commit c48aa05531
3 changed files with 64 additions and 4 deletions

View File

@ -16,6 +16,31 @@ Usage (drop-in replacement for FA3):
import torch
import torch.nn.functional as F
# Try to import flex_attention for efficient sliding window on non-Hopper GPUs
try:
from torch.nn.attention.flex_attention import (
flex_attention as _flex_attn_fn,
create_block_mask as _create_block_mask,
)
HAS_FLEX_ATTN = True
except ImportError:
HAS_FLEX_ATTN = False
# Block mask cache keyed by (T, window, device_str)
_block_mask_cache: dict = {}
def _make_sliding_causal_block_mask(T, window, device):
    """Return the BlockMask for a causal sliding-window attention pattern.

    Results are memoized in `_block_mask_cache` per (T, window, device) so a
    mask is only built once per distinct sequence length / window / device.
    """
    cache_key = (T, window, str(device))
    mask = _block_mask_cache.get(cache_key)
    if mask is None:
        def mask_mod(b, h, q_idx, kv_idx):
            # Causal (no attending ahead) and at most `window` positions back.
            return (q_idx >= kv_idx) & ((q_idx - kv_idx) <= window)
        mask = _create_block_mask(
            mask_mod, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device,
        )
        _block_mask_cache[cache_key] = mask
    return mask
# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
@ -71,6 +96,12 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
if (window < 0 or window >= Tq) and Tq == Tk:
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
# Training: sliding window via flex_attention (block-sparse, no dense mask)
# Tq == Tk here means we're in the training forward pass (causal sliding window).
if Tq == Tk and HAS_FLEX_ATTN:
block_mask = _make_sliding_causal_block_mask(Tq, window, q.device)
return _flex_attn_fn(q, k, v, block_mask=block_mask, enable_gqa=enable_gqa)
# Single token generation
if Tq == 1:
if window >= 0 and window < Tk:

View File

@ -31,7 +31,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from nanochat.flash_attention import HAS_FA3, HAS_FLEX_ATTN
from scripts.base_eval import evaluate_core
print_banner()
@ -105,11 +105,14 @@ if HAS_FA3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else:
print0("!" * 80)
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
print0("WARNING: Flash Attention 3 not available, using PyTorch fallback")
print0("WARNING: Training will be less efficient without FA3")
if args.window_pattern != "L":
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
if HAS_FLEX_ATTN:
print0(f"✓ Using flex_attention for sliding window (window_pattern='{args.window_pattern}'). Block-sparse, efficient.")
else:
print0(f"WARNING: flex_attention not available for sliding window (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
print0("!" * 80)
# -----------------------------------------------------------------------------

View File

@ -62,6 +62,32 @@ def test_sdpa_attention_branches():
assert y3.shape == q3.shape
def test_flex_attention_sliding_window_matches_sdpa(monkeypatch):
    """The Tq==Tk flex_attention sliding-window path must agree with the
    explicit-mask SDPA path on identical inputs."""
    if not fa.HAS_FLEX_ATTN:
        pytest.skip("flex_attention not available")
    torch.manual_seed(0)
    T, window = 16, 4
    shape = (1, 2, T, 8)
    q = torch.randn(shape)
    k = torch.randn(shape)
    v = torch.randn(shape)

    def run(use_flex):
        # Flip the availability flag to select the code path, clearing the
        # mask cache so each pass starts fresh.
        monkeypatch.setattr(fa, "HAS_FLEX_ATTN", use_flex)
        fa._block_mask_cache.clear()
        return fa._sdpa_attention(q, k, v, window_size=(window, 0), enable_gqa=False)

    y_sdpa = run(False)  # explicit-mask SDPA reference
    y_flex = run(True)   # flex_attention block-sparse path
    assert y_flex.shape == y_sdpa.shape
    assert torch.allclose(y_flex, y_sdpa, atol=1e-5), \
        f"flex_attention and SDPA outputs differ: max_diff={( y_flex - y_sdpa).abs().max():.6f}"
def test_public_flash_attn_paths(monkeypatch):
q = torch.randn(1, 3, 2, 4)
k = torch.randn(1, 3, 2, 4)