diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 125625f..c97e6d5 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -104,8 +104,9 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit( nonlocal pq_idx, rg_idx, epoch doc_batch, (pq_idx, rg_idx, epoch) = next(batches) token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) + # Pre-convert to tensors once during buffering to avoid repeated torch.tensor() in inner loop for tokens in token_lists: - doc_buffer.append(tokens) + doc_buffer.append(torch.tensor(tokens, dtype=torch.long)) # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)] # This gives us contiguous views and a single HtoD transfer @@ -128,25 +129,25 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit( remaining = row_capacity - pos - # Find largest doc that fits entirely + # Find largest doc that fits entirely (doc is now a tensor) best_idx = -1 best_len = 0 for i, doc in enumerate(doc_buffer): - doc_len = len(doc) + doc_len = doc.size(0) if doc_len <= remaining and doc_len > best_len: best_idx = i best_len = doc_len if best_idx >= 0: doc = doc_buffer.pop(best_idx) - doc_len = len(doc) - row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long) + doc_len = doc.size(0) + row_buffer[row_idx, pos:pos + doc_len] = doc # Direct tensor copy, no conversion pos += doc_len else: # No doc fits - crop shortest in buffer to fill remaining and minimize waste - shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i])) + shortest_idx = min(range(len(doc_buffer)), key=lambda i: doc_buffer[i].size(0)) doc = doc_buffer.pop(shortest_idx) - row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long) + row_buffer[row_idx, pos:pos + remaining] = doc[:remaining] # Tensor slice, no conversion pos += remaining # Copy to pinned CPU buffer, then single HtoD transfer diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 89ca42b..e70f2f0 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -58,6 +58,39 @@ def _use_fa3(): # ============================================================================= # SDPA helpers # ============================================================================= +from functools import lru_cache + +@lru_cache(maxsize=32) +def _get_sliding_window_mask(Tq: int, Tk: int, window: int, device_index: int): + """ + Create and cache a sliding window attention mask. + + Args: + Tq: Query sequence length + Tk: Key sequence length + window: Sliding window size (-1 for full context) + device_index: CUDA device index (0 for CPU/MPS, else cuda device id) + + Returns: + Boolean mask tensor of shape (Tq, Tk) + """ + if device_index == -1: + device = torch.device("cpu") + else: + device = torch.device(f"cuda:{device_index}") + + # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask + row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) + col_idx = torch.arange(Tk, device=device).unsqueeze(0) + mask = col_idx <= row_idx + + # sliding window (left) + if window >= 0 and window < Tk: + mask = mask & ((row_idx - col_idx) <= window) + + return mask + + def _sdpa_attention(q, k, v, window_size, enable_gqa): """ SDPA attention with sliding window support. @@ -80,16 +113,10 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): v = v[:, :, start:, :] return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) - # Need explicit mask for sliding window/chunk inference + # Need explicit mask for sliding window/chunk inference - use cached mask device = q.device - # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask - row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) - col_idx = torch.arange(Tk, device=device).unsqueeze(0) - mask = col_idx <= row_idx - - # sliding window (left) - if window >= 0 and window < Tk: - mask = mask & ((row_idx - col_idx) <= window) + device_index = device.index if device.type == "cuda" else -1 + mask = _get_sliding_window_mask(Tq, Tk, window, device_index) return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)