mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-20 20:03:19 +00:00
fallback to flex_attention when FA3 is not available
This commit is contained in:
parent
29b76c5695
commit
c48aa05531
|
|
@ -16,6 +16,31 @@ Usage (drop-in replacement for FA3):
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Try to import flex_attention for efficient sliding window on non-Hopper GPUs
try:
    from torch.nn.attention.flex_attention import (
        flex_attention as _flex_attn_fn,
        create_block_mask as _create_block_mask,
    )
except ImportError:
    # torch builds older than 2.5 do not ship flex_attention.
    HAS_FLEX_ATTN = False
else:
    HAS_FLEX_ATTN = True
|
||||
|
||||
# Block mask cache keyed by (T, window, device_str)
|
||||
_block_mask_cache: dict = {}
|
||||
|
||||
|
||||
def _make_sliding_causal_block_mask(T, window, device):
|
||||
"""Create or retrieve cached BlockMask for causal sliding window attention."""
|
||||
key = (T, window, str(device))
|
||||
if key not in _block_mask_cache:
|
||||
def mask_mod(b, h, q_idx, kv_idx):
|
||||
return (q_idx >= kv_idx) & ((q_idx - kv_idx) <= window)
|
||||
_block_mask_cache[key] = _create_block_mask(
|
||||
mask_mod, B=None, H=None, Q_LEN=T, KV_LEN=T, device=device,
|
||||
)
|
||||
return _block_mask_cache[key]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Detection: Try to load FA3 on Hopper+ GPUs
|
||||
|
|
@ -71,6 +96,12 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
|||
if (window < 0 or window >= Tq) and Tq == Tk:
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
|
||||
# Training: sliding window via flex_attention (block-sparse, no dense mask)
|
||||
# Tq == Tk here means we're in the training forward pass (causal sliding window).
|
||||
if Tq == Tk and HAS_FLEX_ATTN:
|
||||
block_mask = _make_sliding_causal_block_mask(Tq, window, q.device)
|
||||
return _flex_attn_fn(q, k, v, block_mask=block_mask, enable_gqa=enable_gqa)
|
||||
|
||||
# Single token generation
|
||||
if Tq == 1:
|
||||
if window >= 0 and window < Tk:
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
|||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from nanochat.flash_attention import HAS_FA3, HAS_FLEX_ATTN
|
||||
from scripts.base_eval import evaluate_core
|
||||
print_banner()
|
||||
|
||||
|
|
@ -105,11 +105,14 @@ if HAS_FA3:
|
|||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
else:
|
||||
print0("!" * 80)
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch fallback")
|
||||
print0("WARNING: Training will be less efficient without FA3")
|
||||
if args.window_pattern != "L":
|
||||
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
||||
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
|
||||
if HAS_FLEX_ATTN:
|
||||
print0(f"✓ Using flex_attention for sliding window (window_pattern='{args.window_pattern}'). Block-sparse, efficient.")
|
||||
else:
|
||||
print0(f"WARNING: flex_attention not available for sliding window (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
||||
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
|
||||
print0("!" * 80)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -62,6 +62,32 @@ def test_sdpa_attention_branches():
|
|||
assert y3.shape == q3.shape
|
||||
|
||||
|
||||
def test_flex_attention_sliding_window_matches_sdpa(monkeypatch):
    """flex_attention sliding window path (Tq==Tk) must match explicit-mask SDPA."""
    if not fa.HAS_FLEX_ATTN:
        pytest.skip("flex_attention not available")

    torch.manual_seed(0)
    T, window = 16, 4
    # Generator unpacking evaluates left-to-right, so q, k, v consume the RNG
    # in the same order as three separate randn() calls would.
    q, k, v = (torch.randn(1, 2, T, 8) for _ in range(3))

    # Force explicit-mask path
    monkeypatch.setattr(fa, "HAS_FLEX_ATTN", False)
    fa._block_mask_cache.clear()
    y_sdpa = fa._sdpa_attention(q, k, v, window_size=(window, 0), enable_gqa=False)

    # Force flex_attention path
    monkeypatch.setattr(fa, "HAS_FLEX_ATTN", True)
    fa._block_mask_cache.clear()
    y_flex = fa._sdpa_attention(q, k, v, window_size=(window, 0), enable_gqa=False)

    assert y_flex.shape == y_sdpa.shape
    max_diff = (y_flex - y_sdpa).abs().max()
    assert torch.allclose(y_flex, y_sdpa, atol=1e-5), \
        f"flex_attention and SDPA outputs differ: max_diff={max_diff:.6f}"
|
||||
|
||||
|
||||
def test_public_flash_attn_paths(monkeypatch):
|
||||
q = torch.randn(1, 3, 2, 4)
|
||||
k = torch.randn(1, 3, 2, 4)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user