diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 89ca42b..c412810 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -16,6 +16,31 @@ Usage (drop-in replacement for FA3): import torch import torch.nn.functional as F +# Try to import flex_attention for efficient sliding window on non-Hopper GPUs +try: + from torch.nn.attention.flex_attention import ( + flex_attention as _flex_attn_fn, + create_block_mask as _create_block_mask, + ) + HAS_FLEX_ATTN = True +except ImportError: + HAS_FLEX_ATTN = False + +# Block mask cache keyed by (T, window, device_str) +_block_mask_cache: dict = {} + + +def _make_sliding_causal_block_mask(T, window, device): + """Create or retrieve cached BlockMask for causal sliding window attention.""" + key = (T, window, str(device)) + if key not in _block_mask_cache: + def mask_mod(b, h, q_idx, kv_idx): + return (q_idx >= kv_idx) & ((q_idx - kv_idx) <= window) + _block_mask_cache[key] = _create_block_mask( + mask_mod, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device, + ) + return _block_mask_cache[key] + # ============================================================================= # Detection: Try to load FA3 on Hopper+ GPUs @@ -71,6 +96,12 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): if (window < 0 or window >= Tq) and Tq == Tk: return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) + # Training: sliding window via flex_attention (block-sparse, no dense mask) + # Tq == Tk here means we're in the training forward pass (causal sliding window). 
+ if Tq == Tk and HAS_FLEX_ATTN: + block_mask = _make_sliding_causal_block_mask(Tq, window, q.device) + return _flex_attn_fn(q, k, v, block_mask=block_mask, enable_gqa=enable_gqa) + # Single token generation if Tq == 1: if window >= 0 and window < Tk: diff --git a/scripts/base_train.py b/scripts/base_train.py index 24091b6..a25608a 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -31,7 +31,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine -from nanochat.flash_attention import HAS_FA3 +from nanochat.flash_attention import HAS_FA3, HAS_FLEX_ATTN from scripts.base_eval import evaluate_core print_banner() @@ -105,11 +105,14 @@ if HAS_FA3: print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") else: print0("!" * 80) - print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") + print0("WARNING: Flash Attention 3 not available, using PyTorch fallback") print0("WARNING: Training will be less efficient without FA3") if args.window_pattern != "L": - print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.") - print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.") + if HAS_FLEX_ATTN: + print0(f"✓ Using flex_attention for sliding window (window_pattern='{args.window_pattern}'). Block-sparse, efficient.") + else: + print0(f"WARNING: flex_attention not available for sliding window (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.") + print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.") print0("!" 
 * 80) # ----------------------------------------------------------------------------- diff --git a/tests/test_flash_attention_new.py b/tests/test_flash_attention_new.py index 028496f..c8bd08f 100644 --- a/tests/test_flash_attention_new.py +++ b/tests/test_flash_attention_new.py @@ -62,6 +62,32 @@ def test_sdpa_attention_branches(): assert y3.shape == q3.shape +def test_flex_attention_sliding_window_matches_sdpa(monkeypatch): + """flex_attention sliding window path (Tq==Tk) must match explicit-mask SDPA.""" + if not fa.HAS_FLEX_ATTN: + pytest.skip("flex_attention not available") + + torch.manual_seed(0) + T, window = 16, 4 + q = torch.randn(1, 2, T, 8) + k = torch.randn(1, 2, T, 8) + v = torch.randn(1, 2, T, 8) + + # Force explicit-mask path + monkeypatch.setattr(fa, "HAS_FLEX_ATTN", False) + fa._block_mask_cache.clear() + y_sdpa = fa._sdpa_attention(q, k, v, window_size=(window, 0), enable_gqa=False) + + # Force flex_attention path + monkeypatch.setattr(fa, "HAS_FLEX_ATTN", True) + fa._block_mask_cache.clear() + y_flex = fa._sdpa_attention(q, k, v, window_size=(window, 0), enable_gqa=False) + + assert y_flex.shape == y_sdpa.shape + assert torch.allclose(y_flex, y_sdpa, atol=1e-5), \ + f"flex_attention and SDPA outputs differ: max_diff={(y_flex - y_sdpa).abs().max():.6f}" + + def test_public_flash_attn_paths(monkeypatch): q = torch.randn(1, 3, 2, 4) k = torch.randn(1, 3, 2, 4)