mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-10 01:40:17 +00:00
add Flash Attention 2 as a middle tier between FA3 and SDPA
on sm80+ non-Hopper GPUs (Blackwell, Ada, Ampere) with the flash-attn package installed, FA2 kernels replace the SDPA fallback. priority is FA3 > FA2 > SDPA. measured 28% faster than SDPA on GB10, and makes sliding-window attention fast on Blackwell (where FA3 is unavailable). no effect on H100: USE_FA3 wins whenever available so runs/speedrun.sh on 8xH100 runs the same kernels as before. tests/test_attention_fallback.py::TestFA2VsSDPA compares FA2 and SDPA output on any sm80+ GPU with flash-attn installed. context: https://github.com/karpathy/nanochat/discussions/710 (the writeup was produced from my dgx-spark branch at https://github.com/matt-langston/nanochat/tree/dgx-spark, which carries these two PRs plus a DGX-Spark-Bundle-specific speedrun script I kept separate) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
0aaca56805
commit
75bd386b8e
|
|
@ -1,8 +1,9 @@
|
|||
"""
|
||||
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
||||
Unified Flash Attention interface with automatic FA3/FA2/SDPA switching.
|
||||
|
||||
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
||||
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
|
||||
If the flash-attn package is installed, FA2 kernels are used instead of SDPA on sm80+ GPUs (FA3 wins if both are usable).
|
||||
|
||||
Usage (drop-in replacement for FA3):
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
|
@ -38,10 +39,25 @@ def _load_flash_attention_3():
|
|||
return None
|
||||
|
||||
|
||||
_fa3 = _load_flash_attention_3()
|
||||
HAS_FA3 = _fa3 is not None
|
||||
def _load_flash_attention_2():
|
||||
"""Try to load Flash Attention 2 (requires flash-attn package, sm80+)."""
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
try:
|
||||
import flash_attn.flash_attn_interface as fa2
|
||||
if hasattr(fa2, 'flash_attn_func') and hasattr(fa2, 'flash_attn_with_kvcache'):
|
||||
return fa2
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
|
||||
|
||||
_fa3 = _load_flash_attention_3()
|
||||
_fa2 = _load_flash_attention_2()
|
||||
HAS_FA3 = _fa3 is not None
|
||||
HAS_FA2 = _fa2 is not None
|
||||
|
||||
# Override for testing: set to 'fa3', 'fa2', 'sdpa', or None (auto)
|
||||
_override_impl = None
|
||||
|
||||
|
||||
|
|
@ -50,6 +66,8 @@ def _resolve_use_fa3():
|
|||
if _override_impl == 'fa3':
|
||||
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
|
||||
return True
|
||||
if _override_impl == 'fa2':
|
||||
return False
|
||||
if _override_impl == 'sdpa':
|
||||
return False
|
||||
if HAS_FA3:
|
||||
|
|
@ -63,6 +81,23 @@ def _resolve_use_fa3():
|
|||
USE_FA3 = _resolve_use_fa3()
|
||||
|
||||
|
||||
def _resolve_use_fa2():
    """Decide once whether FA2 kernels should be used.

    Honors the test override first, defers to FA3 whenever it won the
    resolution, and otherwise enables FA2 only when the flash-attn package
    is present and the compute dtype is half precision (FA2 ships bf16/fp16
    kernels but no fp32 kernels).
    """
    if _override_impl == 'fa2':
        assert HAS_FA2, "Cannot override to FA2: flash-attn package not available"
        return True
    if _override_impl == 'sdpa':
        return False
    if USE_FA3:
        return False  # FA3 takes priority whenever it is usable
    if not HAS_FA2:
        return False
    from nanochat.common import COMPUTE_DTYPE
    return COMPUTE_DTYPE != torch.float32
|
||||
|
||||
USE_FA2 = _resolve_use_fa2()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SDPA helpers
|
||||
# =============================================================================
|
||||
|
|
@ -118,6 +153,8 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
|||
"""
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
if USE_FA2:
|
||||
return _fa2.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
|
||||
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
|
||||
q = q.transpose(1, 2)
|
||||
|
|
@ -151,6 +188,11 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||
causal=causal, window_size=window_size
|
||||
)
|
||||
if USE_FA2:
|
||||
return _fa2.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||
causal=causal, window_size=window_size
|
||||
)
|
||||
|
||||
# SDPA fallback: manually manage KV cache
|
||||
B, T_new, H, D = q.shape
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
|||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from nanochat.flash_attention import HAS_FA3, HAS_FA2
|
||||
from scripts.base_eval import evaluate_core
|
||||
print_banner()
|
||||
|
||||
|
|
@ -100,10 +100,12 @@ use_dummy_wandb = args.run == "dummy" or not master_process
|
|||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
from nanochat.flash_attention import USE_FA3
|
||||
from nanochat.flash_attention import USE_FA3, USE_FA2
|
||||
using_fa3 = USE_FA3
|
||||
if using_fa3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
elif USE_FA2:
|
||||
print0("Using Flash Attention 2 (flash-attn package detected)")
|
||||
else:
|
||||
print0("!" * 80)
|
||||
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from nanochat.tokenizer import get_token_bytes
|
|||
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
import torch.distributed as dist
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from nanochat.flash_attention import HAS_FA3, HAS_FA2
|
||||
from nanochat.engine import Engine
|
||||
from scripts.chat_eval import run_chat_eval
|
||||
|
||||
|
|
@ -89,7 +89,7 @@ use_dummy_wandb = args.run == "dummy" or not master_process
|
|||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if not HAS_FA3:
|
||||
if not HAS_FA3 and not HAS_FA2:
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
|
||||
|
||||
# Load the model and tokenizer
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
"""
|
||||
Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
|
||||
Test Flash Attention unified interface - verify FA3, FA2, and SDPA produce identical results.
|
||||
|
||||
Run: python -m pytest tests/test_attention_fallback.py -v -s
|
||||
|
||||
Note on test structure:
|
||||
Tests are split into two classes due to dtype/device constraints:
|
||||
Tests are split into three classes due to dtype/device constraints:
|
||||
|
||||
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
|
||||
and verify they produce identical results. These require a Hopper GPU (FA3 only
|
||||
|
|
@ -12,18 +12,22 @@ Note on test structure:
|
|||
|
||||
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
|
||||
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
|
||||
|
||||
3. TestFA2VsSDPA: Same as (1), but comparing FA2 (flash-attn package, sm80+) against SDPA.
|
||||
Runs on any sm80+ GPU and is the key correctness test on Blackwell where FA3 is unavailable.
|
||||
"""
|
||||
import torch
|
||||
import pytest
|
||||
import nanochat.flash_attention as fa_module
|
||||
from nanochat.flash_attention import flash_attn, HAS_FA3
|
||||
from nanochat.flash_attention import flash_attn, HAS_FA3, HAS_FA2
|
||||
from nanochat.engine import KVCache
|
||||
|
||||
|
||||
def set_impl(impl):
    """Set the implementation override ('fa3', 'fa2', 'sdpa', or None for auto) and re-resolve USE_FA3/USE_FA2."""
    fa_module._override_impl = impl
    # Recompute the cached module-level flags so the override takes effect
    # immediately. USE_FA3 must be resolved first: _resolve_use_fa2 reads it
    # to give FA3 priority when both implementations are usable.
    fa_module.USE_FA3 = fa_module._resolve_use_fa3()
    fa_module.USE_FA2 = fa_module._resolve_use_fa2()
||||
|
||||
|
||||
def run_both_impls(fn):
|
||||
|
|
@ -36,6 +40,16 @@ def run_both_impls(fn):
|
|||
return out_fa3, out_sdpa
|
||||
|
||||
|
||||
def run_fa2_and_sdpa(fn):
    """Run a function with both FA2 and SDPA, return both outputs.

    The implementation override is restored to auto (None) even if ``fn``
    raises, so a failing test cannot leak a forced implementation into
    tests that run afterwards.
    """
    try:
        set_impl('fa2')
        out_fa2 = fn()
        set_impl('sdpa')
        out_sdpa = fn()
    finally:
        set_impl(None)  # reset
    return out_fa2, out_sdpa
|
||||
|
||||
|
||||
def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
|
||||
"""Assert two tensors are close, with helpful error message."""
|
||||
max_diff = (t1 - t2).abs().max().item()
|
||||
|
|
@ -334,6 +348,206 @@ class TestSDPAOnly:
|
|||
set_impl(None)
|
||||
|
||||
|
||||
# =============================================================================
# FA2 vs SDPA comparison tests (require flash-attn package)
# =============================================================================
@pytest.mark.skipif(not HAS_FA2, reason="FA2 required (flash-attn package)")
class TestFA2VsSDPA:
    """Compare FA2 and SDPA produce identical results. Requires flash-attn package (sm80+)."""

    # FA2 kernels are CUDA-only and support bf16/fp16 (not fp32), so every
    # tensor in this class is created on CUDA in bfloat16.
    DEVICE = "cuda"
    DTYPE = torch.bfloat16

    def test_basic_causal(self):
        """Basic causal attention."""
        B, T, H, D = 2, 64, 4, 32
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            # window_size=(T, 0) spans the whole sequence, i.e. plain causal.
            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))

        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "basic_causal")
        print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_full_context(self):
        """Full context (window_size=-1)."""
        B, T, H, D = 2, 128, 4, 32
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            # (-1, -1) disables windowing entirely in the flash-attn API.
            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))

        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "full_context")
        print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_sliding_window(self):
        """Sliding window attention (key test: FA2 has native sliding kernels, SDPA uses an explicit mask - both must agree so GB10 can run SSSL instead of L)."""
        B, T, H, D = 2, 128, 4, 32
        window = 32  # window < T, so the sliding-window path is actually exercised
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))

        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "sliding_window")
        print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_sliding_window_realistic(self):
        """Realistic sliding window matching depth=24 config (n_embd=1536, n_head=12, head_dim=128, seq_len=2048, SSSL short_window=768)."""
        B, T, H, D = 2, 256, 12, 128  # scaled-down seq_len for speed
        # NOTE(review): window (768) exceeds the scaled-down T (256), so this
        # effectively tests full causal attention with the realistic head
        # config rather than a truncated window — confirm that is intended.
        window = 768
        n_kv_heads = 4  # GQA
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))

        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "sliding_window_realistic")
        print(f"sliding_window_realistic: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_gqa(self):
        """Group Query Attention (fewer KV heads than Q heads)."""
        B, T, D = 2, 64, 32
        n_heads = 8
        n_kv_heads = 2

        q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))

        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "gqa")
        print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_kvcache_prefill(self):
        """Test prefill (inserting multiple tokens into empty cache)."""
        B, T_max, H, D = 2, 64, 4, 32
        T_prefill = 16

        q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            # Fresh empty caches per implementation: flash_attn_with_kvcache
            # mutates the caches in place, so they must not be shared.
            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
            cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
            return flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache, k=k, v=v,
                cache_seqlens=cache_seqlens,
                causal=True, window_size=(T_max, 0)
            )

        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "prefill")
        print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_kvcache_single_token(self):
        """Test single token generation (cache already has content)."""
        B, T_max, H, D = 2, 64, 4, 32
        T_prefill = 16

        k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            # Rebuild the pre-filled cache each call so both implementations
            # see identical starting state.
            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
            k_cache[:, :T_prefill, :, :] = k_init
            v_cache[:, :T_prefill, :, :] = v_init
            cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
            return flash_attn.flash_attn_with_kvcache(
                q_single, k_cache, v_cache, k=k_single, v=v_single,
                cache_seqlens=cache_seqlens,
                causal=True, window_size=(T_max, 0)
            )

        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "single_token")
        print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_kvcache_single_token_sliding_window(self):
        """Test single token decode with sliding window smaller than cache size (SSSL inference path)."""
        B, T_max, H, D = 2, 64, 4, 32
        T_prefill = 32  # Enough tokens to exceed window
        window = 8  # Window SMALLER than cache size

        k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
            v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
            k_cache[:, :T_prefill, :, :] = k_init
            v_cache[:, :T_prefill, :, :] = v_init
            cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
            return flash_attn.flash_attn_with_kvcache(
                q_single, k_cache, v_cache, k=k_single, v=v_single,
                cache_seqlens=cache_seqlens,
                causal=True, window_size=(window, 0)  # window=8 < Tk=33
            )

        y_fa2, y_sdpa = run_fa2_and_sdpa(run)
        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "single_token_sliding_window")
        print(f"single_token_sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

    def test_backward_gradients_match(self):
        """Verify gradients are similar between FA2 and SDPA."""
        B, T, H, D = 2, 32, 4, 16

        q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)

        def run():
            # Fresh leaf tensors per implementation so each backward pass
            # accumulates into its own .grad buffers.
            q = q_data.clone().requires_grad_(True)
            k = k_data.clone().requires_grad_(True)
            v = v_data.clone().requires_grad_(True)
            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
            loss = y.sum()
            loss.backward()
            return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()

        set_impl('fa2')
        y_fa2, q_grad_fa2, k_grad_fa2, v_grad_fa2 = run()
        set_impl('sdpa')
        y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
        set_impl(None)

        max_diff, mean_diff = assert_close(y_fa2, y_sdpa, "backward_output")
        print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

        # Gradients accumulate more numerical error than forward outputs,
        # hence the looser tolerances below.
        max_diff, mean_diff = assert_close(q_grad_fa2, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
        print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

        max_diff, mean_diff = assert_close(k_grad_fa2, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
        print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

        max_diff, mean_diff = assert_close(v_grad_fa2, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
        print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Override mechanism tests
|
||||
# =============================================================================
|
||||
|
|
@ -358,6 +572,19 @@ class TestOverrideMechanism:
|
|||
set_impl(None)
|
||||
assert fa_module.USE_FA3 == HAS_FA3
|
||||
|
||||
@pytest.mark.skipif(not HAS_FA2, reason="FA2 required")
|
||||
def test_override_fa2(self):
|
||||
"""Test that override='fa2' uses FA2."""
|
||||
set_impl('fa2')
|
||||
assert fa_module.USE_FA2 == True
|
||||
set_impl(None)
|
||||
|
||||
@pytest.mark.skipif(not (HAS_FA2 and not HAS_FA3), reason="FA2 auto-selection only applies when FA3 is unavailable")
|
||||
def test_override_auto_fa2(self):
|
||||
"""Test that override=None picks FA2 when FA3 is unavailable."""
|
||||
set_impl(None)
|
||||
assert fa_module.USE_FA2 == True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"PyTorch version: {torch.__version__}")
|
||||
|
|
@ -367,6 +594,7 @@ if __name__ == "__main__":
|
|||
major, minor = torch.cuda.get_device_capability()
|
||||
print(f"Compute capability: {major}.{minor}")
|
||||
print(f"HAS_FA3: {HAS_FA3}")
|
||||
print(f"HAS_FA2: {HAS_FA2}")
|
||||
print()
|
||||
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user