also update comments

This commit is contained in:
svlandeg 2026-02-02 19:41:42 +01:00
parent 433aacf770
commit 535f664cc2
4 changed files with 8 additions and 8 deletions

View File

@@ -2,7 +2,7 @@
Unified Flash Attention interface with automatic FA3/SDPA switching.
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
to PyTorch SDPA on incompatible CUDA GPUs, MPS, and CPU.
Usage (drop-in replacement for FA3):
from nanochat.flash_attention import flash_attn

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
# Our custom Flash Attention module that automatically uses FA3 when compatible and SDPA fallback otherwise
from nanochat.flash_attention import flash_attn
@dataclass
@@ -93,7 +93,7 @@ class CausalSelfAttention(nn.Module):
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
# Flash Attention (FA3 or SDPA fallback)
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window

View File

@@ -96,7 +96,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat",
# Flash Attention status
if HAS_FA3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
print0("✓ Using Flash Attention 3: efficient, new and awesome.")
else:
print0("!" * 80)
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")

View File

@@ -7,8 +7,8 @@ Note on test structure:
Tests are split into two classes due to dtype/device constraints:
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
and verify they produce identical results. These require a Hopper GPU (FA3 only
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
and verify they produce identical results. These require a compatible GPU (FA3 only
works on sm80 and sm90) and use bfloat16 (FA3 doesn't support float32).
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
@@ -45,11 +45,11 @@ def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
# =============================================================================
# FA3 vs SDPA comparison tests (require Hopper GPU)
# FA3 vs SDPA comparison tests
# =============================================================================
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
class TestFA3VsSDPA:
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
"""Compare FA3 and SDPA produce identical results."""
DEVICE = "cuda"
DTYPE = torch.bfloat16