mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-07 01:40:30 +00:00
feat: Introduce BOS-aligned bestfit distributed dataloaders, Flash Attention. (1) Pre-convert tokenized documents to tensors in
dataloader.py buffer to avoid repeated torch.tensor() calls; (2) Added LRU cache for sliding window masks in flash_attention.py to avoid recreating masks on every call.
This commit is contained in:
parent
542beb0c8c
commit
005daea668
|
|
@ -104,8 +104,9 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
|||
nonlocal pq_idx, rg_idx, epoch
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
# Pre-convert to tensors once during buffering to avoid repeated torch.tensor() in inner loop
|
||||
for tokens in token_lists:
|
||||
doc_buffer.append(tokens)
|
||||
doc_buffer.append(torch.tensor(tokens, dtype=torch.long))
|
||||
|
||||
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
|
||||
# This gives us contiguous views and a single HtoD transfer
|
||||
|
|
@ -128,25 +129,25 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
|||
|
||||
remaining = row_capacity - pos
|
||||
|
||||
# Find largest doc that fits entirely
|
||||
# Find largest doc that fits entirely (doc is now a tensor)
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, doc in enumerate(doc_buffer):
|
||||
doc_len = len(doc)
|
||||
doc_len = doc.size(0)
|
||||
if doc_len <= remaining and doc_len > best_len:
|
||||
best_idx = i
|
||||
best_len = doc_len
|
||||
|
||||
if best_idx >= 0:
|
||||
doc = doc_buffer.pop(best_idx)
|
||||
doc_len = len(doc)
|
||||
row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
|
||||
doc_len = doc.size(0)
|
||||
row_buffer[row_idx, pos:pos + doc_len] = doc # Direct tensor copy, no conversion
|
||||
pos += doc_len
|
||||
else:
|
||||
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
|
||||
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
|
||||
shortest_idx = min(range(len(doc_buffer)), key=lambda i: doc_buffer[i].size(0))
|
||||
doc = doc_buffer.pop(shortest_idx)
|
||||
row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
|
||||
row_buffer[row_idx, pos:pos + remaining] = doc[:remaining] # Tensor slice, no conversion
|
||||
pos += remaining
|
||||
|
||||
# Copy to pinned CPU buffer, then single HtoD transfer
|
||||
|
|
|
|||
|
|
@ -58,6 +58,39 @@ def _use_fa3():
|
|||
# =============================================================================
|
||||
# SDPA helpers
|
||||
# =============================================================================
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_sliding_window_mask(Tq: int, Tk: int, window: int, device_index: int):
|
||||
"""
|
||||
Create and cache a sliding window attention mask.
|
||||
|
||||
Args:
|
||||
Tq: Query sequence length
|
||||
Tk: Key sequence length
|
||||
window: Sliding window size (-1 for full context)
|
||||
device_index: CUDA device index (0 for CPU/MPS, else cuda device id)
|
||||
|
||||
Returns:
|
||||
Boolean mask tensor of shape (Tq, Tk)
|
||||
"""
|
||||
if device_index == -1:
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = torch.device(f"cuda:{device_index}")
|
||||
|
||||
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
|
||||
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
|
||||
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
||||
mask = col_idx <= row_idx
|
||||
|
||||
# sliding window (left)
|
||||
if window >= 0 and window < Tk:
|
||||
mask = mask & ((row_idx - col_idx) <= window)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
||||
"""
|
||||
SDPA attention with sliding window support.
|
||||
|
|
@ -80,16 +113,10 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
|||
v = v[:, :, start:, :]
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
|
||||
# Need explicit mask for sliding window/chunk inference
|
||||
# Need explicit mask for sliding window/chunk inference - use cached mask
|
||||
device = q.device
|
||||
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
|
||||
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
|
||||
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
||||
mask = col_idx <= row_idx
|
||||
|
||||
# sliding window (left)
|
||||
if window >= 0 and window < Tk:
|
||||
mask = mask & ((row_idx - col_idx) <= window)
|
||||
device_index = device.index if device.type == "cuda" else -1
|
||||
mask = _get_sliding_window_mask(Tq, Tk, window, device_index)
|
||||
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user