Moved default window pattern to fa3 code and added explanation

This commit is contained in:
Daniel Dudek 2026-02-06 18:24:05 +01:00
parent fdaebf22cc
commit 9caf6690a1
2 changed files with 5 additions and 2 deletions

View File

@ -168,6 +168,9 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
def default_window_pattern():
# FA3 has native support for window pattern otherwise there is none and it's slow -> fallback to just L
return "SSSL" if _use_fa3() else "L"
# =============================================================================
# Export: flash_attn module interface (drop-in replacement for FA3)

View File

@ -31,7 +31,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from nanochat.flash_attention import HAS_FA3, default_window_pattern
from scripts.base_eval import evaluate_core
print_banner()
@ -50,7 +50,7 @@ parser.add_argument("--depth", type=int, default=20, help="depth of the Transfor
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL" if HAS_FA3 else "L", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
parser.add_argument("--window-pattern", type=str, default=default_window_pattern(), help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")