diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index 89ca42b..db9a05b 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -168,6 +168,9 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
     return y_sdpa.transpose(1, 2)  # back to (B, T, H, D)
 
+def default_window_pattern():
+    # FA3 has native sliding-window support; the fallback path doesn't and is slow, so default to full context ("L")
+    return "SSSL" if _use_fa3() else "L"
 
 # =============================================================================
 # Export: flash_attn module interface (drop-in replacement for FA3)
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 3f526be..953a2e5 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -31,7 +31,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
-from nanochat.flash_attention import HAS_FA3
+from nanochat.flash_attention import HAS_FA3, default_window_pattern
 from scripts.base_eval import evaluate_core
 
 print_banner()
@@ -50,7 +50,7 @@ parser.add_argument("--depth", type=int, default=20, help="depth of the Transfor
 parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
 parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
 parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
-parser.add_argument("--window-pattern", type=str, default="SSSL" if HAS_FA3 else "L", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
+parser.add_argument("--window-pattern", type=str, default=default_window_pattern(), help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
 # Training horizon (only one used, in order of precedence)
 parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
 parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
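
For context on what the new default controls: per the --window-pattern help text, the pattern string is tiled across layers, with L meaning full context and S meaning half context. Below is a minimal illustrative sketch of that tiling, not code from this diff; the helper name window_sizes_for_layers is hypothetical, and how nanochat actually consumes the pattern lives elsewhere in the model code.

    # Hypothetical sketch (not part of this patch): tiling a pattern like "SSSL"
    # across layers, per the --window-pattern help text.
    #   L = full context (window == max_seq_len), S = half context.
    def window_sizes_for_layers(pattern: str, n_layers: int, max_seq_len: int) -> list[int]:
        sizes = []
        for layer_idx in range(n_layers):
            kind = pattern[layer_idx % len(pattern)]  # repeat the pattern across layers
            sizes.append(max_seq_len if kind == "L" else max_seq_len // 2)
        return sizes

    # default_window_pattern() returns "SSSL" on the FA3 path, else "L":
    print(window_sizes_for_layers("SSSL", n_layers=8, max_seq_len=2048))
    # [1024, 1024, 1024, 2048, 1024, 1024, 1024, 2048]
    print(window_sizes_for_layers("L", n_layers=4, max_seq_len=2048))
    # [2048, 2048, 2048, 2048]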