mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
Merge 8fc2829db5 into 7808dc7159
This commit is contained in:
commit
12676d918d
|
|
@ -1,8 +1,10 @@
|
|||
"""
|
||||
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
||||
Unified Flash Attention interface with automatic FA4/FA3/SDPA switching.
|
||||
|
||||
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
||||
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
|
||||
Exports `flash_attn` module that matches the FA3 API exactly, but uses:
|
||||
- FA4 (flash-attn-4) on Blackwell (sm100) and Hopper (sm90) GPUs
|
||||
- FA3 (kernels hub) on Hopper (sm90) if FA4 is not installed
|
||||
- PyTorch SDPA fallback on all other hardware (MPS, CPU, older GPUs)
|
||||
|
||||
Usage (drop-in replacement for FA3):
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
|
@ -18,8 +20,22 @@ import torch.nn.functional as F
|
|||
|
||||
|
||||
# =============================================================================
|
||||
# Detection: Try to load FA3 on Hopper+ GPUs
|
||||
# Detection: Try to load FA4 (Hopper + Blackwell), then FA3 (Hopper only)
|
||||
# =============================================================================
|
||||
def _load_flash_attention_4():
|
||||
"""Try to load Flash Attention 4 (supports Hopper sm90 and Blackwell sm100)."""
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
try:
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major not in (9, 10): # FA4 supports sm90 (Hopper) and sm100 (Blackwell)
|
||||
return None
|
||||
from flash_attn.cute import flash_attn_func as fa4_func
|
||||
return fa4_func
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _load_flash_attention_3():
|
||||
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
|
||||
if not torch.cuda.is_available():
|
||||
|
|
@ -38,15 +54,41 @@ def _load_flash_attention_3():
|
|||
return None
|
||||
|
||||
|
||||
_fa3 = _load_flash_attention_3()
|
||||
_fa4_func_raw = _load_flash_attention_4()
|
||||
HAS_FA4 = _fa4_func_raw is not None
|
||||
|
||||
# Wrap FA4 to prevent torch.compile/dynamo from tracing into CuTeDSL internals
|
||||
if HAS_FA4:
|
||||
@torch.compiler.disable
|
||||
def _fa4_func(q, k, v, causal=False, window_size=(None, None)):
|
||||
return _fa4_func_raw(q, k, v, causal=causal, window_size=window_size)
|
||||
else:
|
||||
_fa4_func = None
|
||||
|
||||
_fa3 = _load_flash_attention_3() if not HAS_FA4 else None
|
||||
HAS_FA3 = _fa3 is not None
|
||||
|
||||
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
|
||||
HAS_FLASH = HAS_FA4 or HAS_FA3
|
||||
_impl_name = "FA4" if HAS_FA4 else ("FA3" if HAS_FA3 else "SDPA")
|
||||
|
||||
# Override for testing: set to 'fa4', 'fa3', 'sdpa', or None (auto)
|
||||
_override_impl = None
|
||||
|
||||
|
||||
def _resolve_use_fa4():
|
||||
"""Decide once whether to use FA4."""
|
||||
if _override_impl == 'fa4':
|
||||
assert HAS_FA4, "Cannot override to FA4: not available"
|
||||
return True
|
||||
if _override_impl in ('fa3', 'sdpa'):
|
||||
return False
|
||||
return HAS_FA4
|
||||
|
||||
|
||||
def _resolve_use_fa3():
|
||||
"""Decide once whether to use FA3, based on availability, override, and dtype."""
|
||||
if USE_FA4:
|
||||
return False
|
||||
if _override_impl == 'fa3':
|
||||
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
|
||||
return True
|
||||
|
|
@ -60,6 +102,7 @@ def _resolve_use_fa3():
|
|||
return False
|
||||
return False
|
||||
|
||||
USE_FA4 = _resolve_use_fa4()
|
||||
USE_FA3 = _resolve_use_fa3()
|
||||
|
||||
|
||||
|
|
@ -116,6 +159,13 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
|||
Returns:
|
||||
Output tensor of shape (B, T, H, D)
|
||||
"""
|
||||
if USE_FA4:
|
||||
# FA4 uses None instead of -1 for unlimited window
|
||||
ws = (None if window_size[0] < 0 else window_size[0],
|
||||
None if window_size[1] < 0 else window_size[1])
|
||||
result = _fa4_func(q, k, v, causal=causal, window_size=ws)
|
||||
return result[0] if isinstance(result, tuple) else result
|
||||
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
|||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from nanochat.flash_attention import HAS_FA3, HAS_FLASH, _impl_name
|
||||
from scripts.base_eval import evaluate_core
|
||||
print_banner()
|
||||
|
||||
|
|
@ -100,17 +100,17 @@ use_dummy_wandb = args.run == "dummy" or not master_process
|
|||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
from nanochat.flash_attention import USE_FA3
|
||||
using_fa3 = USE_FA3
|
||||
if using_fa3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
from nanochat.flash_attention import USE_FA3, USE_FA4
|
||||
using_flash = USE_FA4 or USE_FA3
|
||||
if using_flash:
|
||||
print0(f"✓ Using Flash Attention ({_impl_name}), efficient, new and awesome.")
|
||||
else:
|
||||
print0("!" * 80)
|
||||
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
|
||||
print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
|
||||
else:
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||
print0("WARNING: Training will be less efficient without FA3")
|
||||
print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback")
|
||||
print0("WARNING: Training will be less efficient without Flash Attention")
|
||||
if args.window_pattern != "L":
|
||||
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
||||
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from nanochat.tokenizer import get_token_bytes
|
|||
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
import torch.distributed as dist
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from nanochat.flash_attention import HAS_FA3, HAS_FLASH, _impl_name
|
||||
from nanochat.engine import Engine
|
||||
from scripts.chat_eval import run_chat_eval
|
||||
|
||||
|
|
@ -89,8 +89,10 @@ use_dummy_wandb = args.run == "dummy" or not master_process
|
|||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if not HAS_FA3:
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
|
||||
if HAS_FLASH:
|
||||
print0(f"✓ Using Flash Attention ({_impl_name})")
|
||||
else:
|
||||
print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback. Training will be less efficient.")
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user