diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index af2aee3..e7bb234 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -1,8 +1,10 @@
 """
-Unified Flash Attention interface with automatic FA3/SDPA switching.
+Unified Flash Attention interface with automatic FA4/FA3/SDPA switching.
 
-Exports `flash_attn` module that matches the FA3 API exactly, but falls back
-to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
+Exports `flash_attn` module that matches the FA3 API exactly, but uses:
+  - FA4 (flash-attn-4) on Blackwell (sm100) and Hopper (sm90) GPUs
+  - FA3 (kernels hub) on Hopper (sm90) if FA4 is not installed
+  - PyTorch SDPA fallback on all other hardware (MPS, CPU, older GPUs)
 
 Usage (drop-in replacement for FA3):
     from nanochat.flash_attention import flash_attn
@@ -18,8 +20,22 @@ import torch.nn.functional as F
 
 
 # =============================================================================
-# Detection: Try to load FA3 on Hopper+ GPUs
+# Detection: Try to load FA4 (Hopper + Blackwell), then FA3 (Hopper only)
 # =============================================================================
+def _load_flash_attention_4():
+    """Try to load Flash Attention 4 (supports Hopper sm90 and Blackwell sm100)."""
+    if not torch.cuda.is_available():
+        return None
+    try:
+        major, _ = torch.cuda.get_device_capability()
+        if major not in (9, 10):  # FA4 supports sm90 (Hopper) and sm100 (Blackwell)
+            return None
+        from flash_attn.cute import flash_attn_func as fa4_func
+        return fa4_func
+    except Exception:
+        return None
+
+
 def _load_flash_attention_3():
     """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
     if not torch.cuda.is_available():
@@ -38,15 +54,44 @@ def _load_flash_attention_3():
         return None
 
 
-_fa3 = _load_flash_attention_3()
+_fa4_func_raw = _load_flash_attention_4()
+HAS_FA4 = _fa4_func_raw is not None
+
+# Wrap FA4 to prevent torch.compile/dynamo from tracing into CuTeDSL internals
+if HAS_FA4:
+    @torch.compiler.disable
+    def _fa4_func(q, k, v, causal=False, window_size=(None, None)):
+        """Call the raw FA4 kernel outside of torch.compile tracing."""
+        return _fa4_func_raw(q, k, v, causal=causal, window_size=window_size)
+else:
+    _fa4_func = None
+
+_fa3 = _load_flash_attention_3() if not HAS_FA4 else None
 HAS_FA3 = _fa3 is not None
 
-# Override for testing: set to 'fa3', 'sdpa', or None (auto)
+# Availability flags (HAS_*) describe what could be loaded; the USE_* flags
+# resolved below are what flash_attn_func actually dispatches on.
+HAS_FLASH = HAS_FA4 or HAS_FA3
+_impl_name = "FA4" if HAS_FA4 else ("FA3" if HAS_FA3 else "SDPA")
+
+# Override for testing: set to 'fa4', 'fa3', 'sdpa', or None (auto)
 _override_impl = None
 
 
+def _resolve_use_fa4():
+    """Decide once whether to use FA4, based on availability and override."""
+    if _override_impl == 'fa4':
+        assert HAS_FA4, "Cannot override to FA4: not available"
+        return True
+    if _override_impl in ('fa3', 'sdpa'):
+        return False
+    return HAS_FA4
+
+
 def _resolve_use_fa3():
     """Decide once whether to use FA3, based on availability, override, and dtype."""
+    if USE_FA4:
+        return False
     if _override_impl == 'fa3':
         assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
         return True
@@ -60,6 +105,7 @@ def _resolve_use_fa3():
         return False
     return False
 
+USE_FA4 = _resolve_use_fa4()
 USE_FA3 = _resolve_use_fa3()
 
 
@@ -116,6 +162,13 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
     Returns:
         Output tensor of shape (B, T, H, D)
     """
+    if USE_FA4:
+        # FA4 uses None instead of -1 for unlimited window
+        ws = (None if window_size[0] < 0 else window_size[0],
+              None if window_size[1] < 0 else window_size[1])
+        result = _fa4_func(q, k, v, causal=causal, window_size=ws)
+        return result[0] if isinstance(result, tuple) else result
+
     if USE_FA3:
         return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
 
diff --git a/scripts/base_train.py b/scripts/base_train.py
index a161c47..a9a93b7 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -32,7 +32,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
-from nanochat.flash_attention import HAS_FA3
+from nanochat.flash_attention import HAS_FA3, HAS_FLASH, _impl_name
 from scripts.base_eval import evaluate_core
 print_banner()
 
@@ -100,17 +100,17 @@ use_dummy_wandb = args.run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
 
 # Flash Attention status
-from nanochat.flash_attention import USE_FA3
-using_fa3 = USE_FA3
-if using_fa3:
-    print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
+from nanochat.flash_attention import USE_FA3, USE_FA4
+using_flash = USE_FA4 or USE_FA3
+if using_flash:
+    print0(f"✓ Using Flash Attention ({_impl_name}), efficient, new and awesome.")
 else:
     print0("!" * 80)
     if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
         print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
     else:
-        print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
-    print0("WARNING: Training will be less efficient without FA3")
+        print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback")
+    print0("WARNING: Training will be less efficient without Flash Attention")
     if args.window_pattern != "L":
         print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
         print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index b46dd81..91e29b3 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -21,7 +21,7 @@ from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
 from nanochat.loss_eval import evaluate_bpb
 import torch.distributed as dist
-from nanochat.flash_attention import HAS_FA3
+from nanochat.flash_attention import HAS_FA3, USE_FA3, USE_FA4, _impl_name
 from nanochat.engine import Engine
 from scripts.chat_eval import run_chat_eval
 
@@ -89,8 +89,12 @@ use_dummy_wandb = args.run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
 
 # Flash Attention status
-if not HAS_FA3:
-    print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
+# Report the backend actually selected (USE_*), not merely installed (HAS_*):
+# e.g. FA3 may be installed yet skipped when COMPUTE_DTYPE is not bf16.
+if USE_FA4 or USE_FA3:
+    print0(f"✓ Using Flash Attention ({_impl_name})")
+else:
+    print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback. Training will be less efficient.")
 
 # Load the model and tokenizer
 model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)