This commit is contained in:
Yixin Liu 2026-03-26 06:26:41 +08:00 committed by GitHub
commit 12676d918d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 68 additions and 16 deletions

View File

@@ -1,8 +1,10 @@
"""
Unified Flash Attention interface with automatic FA3/SDPA switching.
Unified Flash Attention interface with automatic FA4/FA3/SDPA switching.
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
Exports `flash_attn` module that matches the FA3 API exactly, but uses:
- FA4 (flash-attn-4) on Blackwell (sm100) and Hopper (sm90) GPUs
- FA3 (kernels hub) on Hopper (sm90) if FA4 is not installed
- PyTorch SDPA fallback on all other hardware (MPS, CPU, older GPUs)
Usage (drop-in replacement for FA3):
from nanochat.flash_attention import flash_attn
@@ -18,8 +20,22 @@ import torch.nn.functional as F
# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
# Detection: Try to load FA4 (Hopper + Blackwell), then FA3 (Hopper only)
# =============================================================================
def _load_flash_attention_4():
"""Try to load Flash Attention 4 (supports Hopper sm90 and Blackwell sm100)."""
if not torch.cuda.is_available():
return None
try:
major, _ = torch.cuda.get_device_capability()
if major not in (9, 10): # FA4 supports sm90 (Hopper) and sm100 (Blackwell)
return None
from flash_attn.cute import flash_attn_func as fa4_func
return fa4_func
except Exception:
return None
def _load_flash_attention_3():
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
if not torch.cuda.is_available():
@@ -38,15 +54,41 @@ def _load_flash_attention_3():
return None
_fa3 = _load_flash_attention_3()
_fa4_func_raw = _load_flash_attention_4()
HAS_FA4 = _fa4_func_raw is not None
# Wrap FA4 to prevent torch.compile/dynamo from tracing into CuTeDSL internals
if HAS_FA4:
@torch.compiler.disable
def _fa4_func(q, k, v, causal=False, window_size=(None, None)):
return _fa4_func_raw(q, k, v, causal=causal, window_size=window_size)
else:
_fa4_func = None
_fa3 = _load_flash_attention_3() if not HAS_FA4 else None
HAS_FA3 = _fa3 is not None
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
HAS_FLASH = HAS_FA4 or HAS_FA3
_impl_name = "FA4" if HAS_FA4 else ("FA3" if HAS_FA3 else "SDPA")
# Override for testing: set to 'fa4', 'fa3', 'sdpa', or None (auto)
_override_impl = None
def _resolve_use_fa4():
"""Decide once whether to use FA4."""
if _override_impl == 'fa4':
assert HAS_FA4, "Cannot override to FA4: not available"
return True
if _override_impl in ('fa3', 'sdpa'):
return False
return HAS_FA4
def _resolve_use_fa3():
"""Decide once whether to use FA3, based on availability, override, and dtype."""
if USE_FA4:
return False
if _override_impl == 'fa3':
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
return True
@@ -60,6 +102,7 @@ def _resolve_use_fa3():
return False
return False
USE_FA4 = _resolve_use_fa4()
USE_FA3 = _resolve_use_fa3()
@@ -116,6 +159,13 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
Returns:
Output tensor of shape (B, T, H, D)
"""
if USE_FA4:
# FA4 uses None instead of -1 for unlimited window
ws = (None if window_size[0] < 0 else window_size[0],
None if window_size[1] < 0 else window_size[1])
result = _fa4_func(q, k, v, causal=causal, window_size=ws)
return result[0] if isinstance(result, tuple) else result
if USE_FA3:
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)

View File

@@ -32,7 +32,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from nanochat.flash_attention import HAS_FA3, HAS_FLASH, _impl_name
from scripts.base_eval import evaluate_core
print_banner()
@@ -100,17 +100,17 @@ use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Flash Attention status
from nanochat.flash_attention import USE_FA3
using_fa3 = USE_FA3
if using_fa3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
from nanochat.flash_attention import USE_FA3, USE_FA4
using_flash = USE_FA4 or USE_FA3
if using_flash:
print0(f"✓ Using Flash Attention ({_impl_name}), efficient, new and awesome.")
else:
print0("!" * 80)
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
else:
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
print0("WARNING: Training will be less efficient without FA3")
print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback")
print0("WARNING: Training will be less efficient without Flash Attention")
if args.window_pattern != "L":
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")

View File

@@ -21,7 +21,7 @@ from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
from nanochat.loss_eval import evaluate_bpb
import torch.distributed as dist
from nanochat.flash_attention import HAS_FA3
from nanochat.flash_attention import HAS_FA3, HAS_FLASH, _impl_name
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
@@ -89,8 +89,10 @@ use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
# Flash Attention status
if not HAS_FA3:
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
if HAS_FLASH:
print0(f"✓ Using Flash Attention ({_impl_name})")
else:
print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback. Training will be less efficient.")
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)