Commit Graph

4 Commits

Author SHA1 Message Date
Matt Langston
75bd386b8e
add Flash Attention 2 as a middle tier between FA3 and SDPA
on sm80+ non-Hopper GPUs (Blackwell, Ada, Ampere) with the flash-attn package installed, FA2 kernels replace the SDPA fallback. priority is FA3 > FA2 > SDPA. measured 28% faster than SDPA on GB10, and makes sliding-window attention fast on Blackwell (where FA3 is unavailable). no effect on H100: USE_FA3 wins whenever available so runs/speedrun.sh on 8xH100 runs the same kernels as before. tests/test_attention_fallback.py::TestFA2VsSDPA compares FA2 and SDPA output on any sm80+ GPU with flash-attn installed.

context: https://github.com/karpathy/nanochat/discussions/710 (the writeup was produced from my dgx-spark branch at https://github.com/matt-langston/nanochat/tree/dgx-spark, which carries these two PRs plus a DGX-Spark-Bundle-specific speedrun script I kept separate)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-17 19:51:22 -07:00
Andrej Karpathy
1076f97059 delete autocast, an unnecessary thorn in my side, manage dtypes directly 2026-03-04 23:55:30 +00:00
Andrej Karpathy
3ba42e8135 Fix SDPA KV-cache decode to respect sliding window (#456)
SDPA fallback now respects sliding window during single-token KV-cache
decode by slicing K/V to the last (window + 1) tokens.

Also simplifies the mask building for chunk inference to properly apply
sliding window in that path as well.

Fixes #452

Co-Authored-By: Kartik Vashishta <kartikv776@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-30 17:32:12 +00:00
Andrej Karpathy
8203efa919 implement flash attention 3 fallback to pytorch sdpa by touching as few lines of code as possible in main files and keeping all implementation to a single file. add tests. add helpful warning messages for the user. 2026-01-16 17:37:51 +00:00