feat: Introduce BOS-aligned bestfit distributed dataloaders, Flash Attention. (1) Pre-convert tokenized documents to tensors in the dataloader.py buffer to avoid repeated torch.tensor() calls; (2) Added LRU cache for sliding window masks in flash_attention.py to avoid recreating masks on every call.
This commit is contained in:
Emanuele 2026-02-04 11:26:48 +01:00
parent 542beb0c8c
commit 005daea668
2 changed files with 44 additions and 16 deletions

View File

@@ -104,8 +104,9 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
nonlocal pq_idx, rg_idx, epoch
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
# Pre-convert to tensors once during buffering to avoid repeated torch.tensor() in inner loop
for tokens in token_lists:
doc_buffer.append(tokens)
doc_buffer.append(torch.tensor(tokens, dtype=torch.long))
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
# This gives us contiguous views and a single HtoD transfer
@@ -128,25 +129,25 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
remaining = row_capacity - pos
# Find largest doc that fits entirely
# Find largest doc that fits entirely (doc is now a tensor)
best_idx = -1
best_len = 0
for i, doc in enumerate(doc_buffer):
doc_len = len(doc)
doc_len = doc.size(0)
if doc_len <= remaining and doc_len > best_len:
best_idx = i
best_len = doc_len
if best_idx >= 0:
doc = doc_buffer.pop(best_idx)
doc_len = len(doc)
row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
doc_len = doc.size(0)
row_buffer[row_idx, pos:pos + doc_len] = doc # Direct tensor copy, no conversion
pos += doc_len
else:
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
shortest_idx = min(range(len(doc_buffer)), key=lambda i: doc_buffer[i].size(0))
doc = doc_buffer.pop(shortest_idx)
row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
row_buffer[row_idx, pos:pos + remaining] = doc[:remaining] # Tensor slice, no conversion
pos += remaining
# Copy to pinned CPU buffer, then single HtoD transfer

View File

@@ -58,6 +58,39 @@ def _use_fa3():
# =============================================================================
# SDPA helpers
# =============================================================================
from functools import lru_cache
@lru_cache(maxsize=32)
def _get_sliding_window_mask(Tq: int, Tk: int, window: int, device_index: int):
"""
Create and cache a sliding window attention mask.
Args:
Tq: Query sequence length
Tk: Key sequence length
window: Sliding window size (-1 for full context)
device_index: CUDA device index (0 for CPU/MPS, else cuda device id)
Returns:
Boolean mask tensor of shape (Tq, Tk)
"""
if device_index == -1:
device = torch.device("cpu")
else:
device = torch.device(f"cuda:{device_index}")
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = col_idx <= row_idx
# sliding window (left)
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
return mask
def _sdpa_attention(q, k, v, window_size, enable_gqa):
"""
SDPA attention with sliding window support.
@@ -80,16 +113,10 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
v = v[:, :, start:, :]
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Need explicit mask for sliding window/chunk inference
# Need explicit mask for sliding window/chunk inference - use cached mask
device = q.device
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = col_idx <= row_idx
# sliding window (left)
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
device_index = device.index if device.type == "cuda" else -1
mask = _get_sliding_window_mask(Tq, Tk, window, device_index)
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)