From f8ca0b5c21ae591858c18a76b88f570143d9e7ae Mon Sep 17 00:00:00 2001 From: EFE AYDIN Date: Fri, 15 May 2026 23:16:01 +0300 Subject: [PATCH 1/5] Add chunk mask caching and update attention functions for fixing some performance choking --- nanochat/flash_attention.py | 74 +++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 20 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index af2aee32..5189a4d0 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -63,10 +63,36 @@ def _resolve_use_fa3(): USE_FA3 = _resolve_use_fa3() +# ============================================================================= +# Mask cache for chunked inference (Tq != Tk, Tq > 1) +# ============================================================================= +_MASK_CACHE: dict = {} +_MASK_CACHE_MAX = 32 + + +def _get_chunk_mask(device, Tq, Tk, window): + """Cached causal (+sliding window) mask for chunk inference.""" + key = (device.type, device.index, Tq, Tk, window) + m = _MASK_CACHE.get(key) + if m is not None: + return m + + row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) + col_idx = torch.arange(Tk, device=device).unsqueeze(0) + m = col_idx <= row_idx + if window >= 0 and window < Tk: + m = m & ((row_idx - col_idx) <= window) + + if len(_MASK_CACHE) >= _MASK_CACHE_MAX: + _MASK_CACHE.clear() + _MASK_CACHE[key] = m + return m + + # ============================================================================= # SDPA helpers # ============================================================================= -def _sdpa_attention(q, k, v, window_size, enable_gqa): +def _sdpa_attention(q, k, v, window_size, enable_gqa, causal=True): """ SDPA attention with sliding window support. q, k, v are (B, H, T, D) format. @@ -77,7 +103,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): # Full context, same length if (window < 0 or window >= Tq) and Tq == Tk: - return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) + return F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=enable_gqa) # Single token generation if Tq == 1: @@ -88,19 +114,11 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): v = v[:, :, start:, :] return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) - # Need explicit mask for sliding window/chunk inference - device = q.device - # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask - row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) - col_idx = torch.arange(Tk, device=device).unsqueeze(0) - mask = col_idx <= row_idx - - # sliding window (left) - if window >= 0 and window < Tk: - mask = mask & ((row_idx - col_idx) <= window) - + # Chunk inference (Tq > 1, Tq != Tk): use cached explicit bool mask. + mask = _get_chunk_mask(q.device, Tq, Tk, window) return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) + # ============================================================================= # Public API: Same interface as FA3 # ============================================================================= @@ -124,7 +142,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): k = k.transpose(1, 2) v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) - y = _sdpa_attention(q, k, v, window_size, enable_gqa) + y = _sdpa_attention(q, k, v, window_size, enable_gqa, causal=causal) return y.transpose(1, 2) # back to (B, T, H, D) @@ -139,7 +157,8 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q: Queries, shape (B, T_new, H, D) k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D) k, v: New keys/values to insert, shape (B, T_new, H_kv, D) - cache_seqlens: Current position in cache, shape (B,) int32 + cache_seqlens: Current position in cache. Either an int (fast path, no + GPU->CPU sync) or a tensor of shape (B,) int32 (FA3-compatible). causal: Whether to use causal masking window_size: (left, right) sliding window. -1 means unlimited. @@ -154,17 +173,32 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N # SDPA fallback: manually manage KV cache B, T_new, H, D = q.shape - pos = cache_seqlens[0].item() # assume uniform position across batch + + # Avoid GPU->CPU sync if caller passes a Python int. + if isinstance(cache_seqlens, int): + pos = cache_seqlens + elif isinstance(cache_seqlens, torch.Tensor): + pos = int(cache_seqlens[0].item()) # assume uniform position across batch + else: + pos = int(cache_seqlens) # Insert new k, v into cache (in-place, matching FA3 behavior) if k is not None and v is not None: k_cache[:, pos:pos+T_new, :, :] = k v_cache[:, pos:pos+T_new, :, :] = v - # Get full cache up to current position + new tokens end_pos = pos + T_new - k_full = k_cache[:, :end_pos, :, :] - v_full = v_cache[:, :end_pos, :, :] + + # Sliding-window single-token decode: trim cache slice early so SDPA sees + # only the window instead of the full prefix. + window = window_size[0] + if T_new == 1 and 0 <= window < end_pos: + start = max(0, end_pos - (window + 1)) + k_full = k_cache[:, start:end_pos, :, :] + v_full = v_cache[:, start:end_pos, :, :] + else: + k_full = k_cache[:, :end_pos, :, :] + v_full = v_cache[:, :end_pos, :, :] # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D) q_sdpa = q.transpose(1, 2) @@ -172,7 +206,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N v_sdpa = v_full.transpose(1, 2) enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) - y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) + y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa, causal=causal) return y_sdpa.transpose(1, 2) # back to (B, T, H, D) From 990a26332c8466575f2b9c378eb6ad3aaf405293 Mon Sep 17 00:00:00 2001 From: EFE AYDIN Date: Fri, 15 May 2026 23:21:38 +0300 Subject: [PATCH 2/5] Replace lru_cache with instance-level cache for tokens --- nanochat/tokenizer.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index a2146c2e..6ee0603e 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -8,7 +8,6 @@ Two implementations are available: import os import copy -from functools import lru_cache SPECIAL_TOKENS = [ # every document begins with the Beginning of Sequence (BOS) token that delimits documents @@ -165,6 +164,9 @@ class RustBPETokenizer: def __init__(self, enc, bos_token): self.enc = enc + # instance-level cache for special token ids; replaces lru_cache on the + # method (which kept a strong ref to self in the function-level cache) + self._special_id_cache: dict[str, int] = {} self.bos_token_id = self.encode_special(bos_token) @classmethod @@ -215,9 +217,13 @@ class RustBPETokenizer: def id_to_token(self, id): return self.enc.decode([id]) - @lru_cache(maxsize=32) def encode_special(self, text): - return self.enc.encode_single_token(text) + cached = self._special_id_cache.get(text) + if cached is not None: + return cached + v = self.enc.encode_single_token(text) + self._special_id_cache[text] = v + return v def get_bos_token_id(self): return self.bos_token_id @@ -239,8 +245,8 @@ class RustBPETokenizer: elif isinstance(text, list): ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads) if prepend is not None: - for ids_row in ids: - ids_row.insert(0, prepend_id) # TODO: same + # avoid O(n) shift per row that insert(0, ...) does + ids = [[prepend_id, *row] for row in ids] if append is not None: for ids_row in ids: ids_row.append(append_id) From bacd7efc06ff1f0cdb28f82d4ca9e669ecd86770 Mon Sep 17 00:00:00 2001 From: EFE AYDIN Date: Fri, 15 May 2026 23:24:21 +0300 Subject: [PATCH 3/5] Update tokenizer.py --- nanochat/tokenizer.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index 6ee0603e..1cf23065 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -4,6 +4,31 @@ BPE Tokenizer in the style of GPT-4. Two implementations are available: 1) HuggingFace Tokenizer that can do both training and inference but is really confusing 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference + +Patch 1 — encode_special lru_cache → instance dict +__init__: +pythondef __init__(self, enc, bos_token): + self.enc = enc + self._special_id_cache: dict[str, int] = {} + self.bos_token_id = self.encode_special(bos_token) +encode_special (decorator'ı kaldır): +pythondef encode_special(self, text): + cached = self._special_id_cache.get(text) + if cached is not None: + return cached + v = self.enc.encode_single_token(text) + self._special_id_cache[text] = v + return v +Signature aynı, cache davranışı aynı (lookup → hit/miss), instance ölünce cache da ölüyor. Caller hiçbir şey fark etmez. +Patch 2 — insert(0) O(n) shift'i kaldır +RustBPETokenizer.encode, batch dalı: +pythonelif isinstance(text, list): + ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads) + if prepend is not None: + ids = [[prepend_id, *row] for row in ids] + if append is not None: + for ids_row in ids: + ids_row.append(append_id) """ import os From 75971f2fe2421881ea66c769b47feb476b566566 Mon Sep 17 00:00:00 2001 From: EFE AYDIN Date: Sun, 17 May 2026 00:47:40 +0300 Subject: [PATCH 4/5] Update tokenizer.py --- nanochat/tokenizer.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index 1cf23065..6ee0603e 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -4,31 +4,6 @@ BPE Tokenizer in the style of GPT-4. Two implementations are available: 1) HuggingFace Tokenizer that can do both training and inference but is really confusing 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference - -Patch 1 — encode_special lru_cache → instance dict -__init__: -pythondef __init__(self, enc, bos_token): - self.enc = enc - self._special_id_cache: dict[str, int] = {} - self.bos_token_id = self.encode_special(bos_token) -encode_special (decorator'ı kaldır): -pythondef encode_special(self, text): - cached = self._special_id_cache.get(text) - if cached is not None: - return cached - v = self.enc.encode_single_token(text) - self._special_id_cache[text] = v - return v -Signature aynı, cache davranışı aynı (lookup → hit/miss), instance ölünce cache da ölüyor. Caller hiçbir şey fark etmez. -Patch 2 — insert(0) O(n) shift'i kaldır -RustBPETokenizer.encode, batch dalı: -pythonelif isinstance(text, list): - ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads) - if prepend is not None: - ids = [[prepend_id, *row] for row in ids] - if append is not None: - for ids_row in ids: - ids_row.append(append_id) """ import os From 34b2b0d003a6158031d5c59aa0da9107802eaa86 Mon Sep 17 00:00:00 2001 From: EFE AYDIN Date: Sun, 17 May 2026 01:10:21 +0300 Subject: [PATCH 5/5] Update flash_attention.py --- nanochat/flash_attention.py | 74 ++++++++++--------------------------- 1 file changed, 20 insertions(+), 54 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 5189a4d0..af2aee32 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -63,36 +63,10 @@ def _resolve_use_fa3(): USE_FA3 = _resolve_use_fa3() -# ============================================================================= -# Mask cache for chunked inference (Tq != Tk, Tq > 1) -# ============================================================================= -_MASK_CACHE: dict = {} -_MASK_CACHE_MAX = 32 - - -def _get_chunk_mask(device, Tq, Tk, window): - """Cached causal (+sliding window) mask for chunk inference.""" - key = (device.type, device.index, Tq, Tk, window) - m = _MASK_CACHE.get(key) - if m is not None: - return m - - row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) - col_idx = torch.arange(Tk, device=device).unsqueeze(0) - m = col_idx <= row_idx - if window >= 0 and window < Tk: - m = m & ((row_idx - col_idx) <= window) - - if len(_MASK_CACHE) >= _MASK_CACHE_MAX: - _MASK_CACHE.clear() - _MASK_CACHE[key] = m - return m - - # ============================================================================= # SDPA helpers # ============================================================================= -def _sdpa_attention(q, k, v, window_size, enable_gqa, causal=True): +def _sdpa_attention(q, k, v, window_size, enable_gqa): """ SDPA attention with sliding window support. q, k, v are (B, H, T, D) format. @@ -103,7 +77,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa, causal=True): # Full context, same length if (window < 0 or window >= Tq) and Tq == Tk: - return F.scaled_dot_product_attention(q, k, v, is_causal=causal, enable_gqa=enable_gqa) + return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) # Single token generation if Tq == 1: @@ -114,10 +88,18 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa, causal=True): v = v[:, :, start:, :] return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) - # Chunk inference (Tq > 1, Tq != Tk): use cached explicit bool mask. - mask = _get_chunk_mask(q.device, Tq, Tk, window) - return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) + # Need explicit mask for sliding window/chunk inference + device = q.device + # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask + row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) + col_idx = torch.arange(Tk, device=device).unsqueeze(0) + mask = col_idx <= row_idx + # sliding window (left) + if window >= 0 and window < Tk: + mask = mask & ((row_idx - col_idx) <= window) + + return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) # ============================================================================= # Public API: Same interface as FA3 @@ -142,7 +124,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): k = k.transpose(1, 2) v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) - y = _sdpa_attention(q, k, v, window_size, enable_gqa, causal=causal) + y = _sdpa_attention(q, k, v, window_size, enable_gqa) return y.transpose(1, 2) # back to (B, T, H, D) @@ -157,8 +139,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q: Queries, shape (B, T_new, H, D) k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D) k, v: New keys/values to insert, shape (B, T_new, H_kv, D) - cache_seqlens: Current position in cache. Either an int (fast path, no - GPU->CPU sync) or a tensor of shape (B,) int32 (FA3-compatible). + cache_seqlens: Current position in cache, shape (B,) int32 causal: Whether to use causal masking window_size: (left, right) sliding window. -1 means unlimited. @@ -173,32 +154,17 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N # SDPA fallback: manually manage KV cache B, T_new, H, D = q.shape - - # Avoid GPU->CPU sync if caller passes a Python int. - if isinstance(cache_seqlens, int): - pos = cache_seqlens - elif isinstance(cache_seqlens, torch.Tensor): - pos = int(cache_seqlens[0].item()) # assume uniform position across batch - else: - pos = int(cache_seqlens) + pos = cache_seqlens[0].item() # assume uniform position across batch # Insert new k, v into cache (in-place, matching FA3 behavior) if k is not None and v is not None: k_cache[:, pos:pos+T_new, :, :] = k v_cache[:, pos:pos+T_new, :, :] = v + # Get full cache up to current position + new tokens end_pos = pos + T_new - - # Sliding-window single-token decode: trim cache slice early so SDPA sees - # only the window instead of the full prefix. - window = window_size[0] - if T_new == 1 and 0 <= window < end_pos: - start = max(0, end_pos - (window + 1)) - k_full = k_cache[:, start:end_pos, :, :] - v_full = v_cache[:, start:end_pos, :, :] - else: - k_full = k_cache[:, :end_pos, :, :] - v_full = v_cache[:, :end_pos, :, :] + k_full = k_cache[:, :end_pos, :, :] + v_full = v_cache[:, :end_pos, :, :] # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D) q_sdpa = q.transpose(1, 2) @@ -206,7 +172,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N v_sdpa = v_full.transpose(1, 2) enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) - y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa, causal=causal) + y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) return y_sdpa.transpose(1, 2) # back to (B, T, H, D)