Compare commits

...

4 Commits

Author SHA1 Message Date
Emanuele
9e95480bbd
Merge 005daea668 into 1ec0a34779 2026-02-08 20:28:23 +02:00
Andrej Karpathy
1ec0a34779 at 28 and above we start to need batch size 8 2026-02-08 18:26:34 +00:00
Andrej Karpathy
ff46300720 tune miniseries just a bit, fairly cosmetic, keep to even depths where the math works out nicely in model sizing 2026-02-08 17:54:12 +00:00
Emanuele
005daea668 feat: Introduce BOS-aligned bestfit distributed dataloaders, Flash Attention. (1) Pre-convert tokenized documents to tensors in
dataloader.py
 buffer to avoid repeated torch.tensor() calls; (2) Added LRU cache for sliding window masks in
flash_attention.py
 to avoid recreating masks on every call.
2026-02-04 11:26:48 +01:00
3 changed files with 55 additions and 19 deletions

View File

@ -104,8 +104,9 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
nonlocal pq_idx, rg_idx, epoch
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
# Pre-convert to tensors once during buffering to avoid repeated torch.tensor() in inner loop
for tokens in token_lists:
doc_buffer.append(tokens)
doc_buffer.append(torch.tensor(tokens, dtype=torch.long))
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
# This gives us contiguous views and a single HtoD transfer
@ -128,25 +129,25 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
remaining = row_capacity - pos
# Find largest doc that fits entirely
# Find largest doc that fits entirely (doc is now a tensor)
best_idx = -1
best_len = 0
for i, doc in enumerate(doc_buffer):
doc_len = len(doc)
doc_len = doc.size(0)
if doc_len <= remaining and doc_len > best_len:
best_idx = i
best_len = doc_len
if best_idx >= 0:
doc = doc_buffer.pop(best_idx)
doc_len = len(doc)
row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
doc_len = doc.size(0)
row_buffer[row_idx, pos:pos + doc_len] = doc # Direct tensor copy, no conversion
pos += doc_len
else:
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
shortest_idx = min(range(len(doc_buffer)), key=lambda i: doc_buffer[i].size(0))
doc = doc_buffer.pop(shortest_idx)
row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
row_buffer[row_idx, pos:pos + remaining] = doc[:remaining] # Tensor slice, no conversion
pos += remaining
# Copy to pinned CPU buffer, then single HtoD transfer

View File

@ -58,6 +58,39 @@ def _use_fa3():
# =============================================================================
# SDPA helpers
# =============================================================================
from functools import lru_cache
@lru_cache(maxsize=32)
def _get_sliding_window_mask(Tq: int, Tk: int, window: int, device_index: int):
"""
Create and cache a sliding window attention mask.
Args:
Tq: Query sequence length
Tk: Key sequence length
window: Sliding window size (-1 for full context)
device_index: CUDA device index (0 for CPU/MPS, else cuda device id)
Returns:
Boolean mask tensor of shape (Tq, Tk)
"""
if device_index == -1:
device = torch.device("cpu")
else:
device = torch.device(f"cuda:{device_index}")
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = col_idx <= row_idx
# sliding window (left)
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
return mask
def _sdpa_attention(q, k, v, window_size, enable_gqa):
"""
SDPA attention with sliding window support.
@ -80,16 +113,10 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
v = v[:, :, start:, :]
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Need explicit mask for sliding window/chunk inference
# Need explicit mask for sliding window/chunk inference - use cached mask
device = q.device
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = col_idx <= row_idx
# sliding window (left)
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
device_index = device.index if device.type == "cuda" else -1
mask = _get_sliding_window_mask(Tq, Tk, window, device_index)
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)

View File

@ -28,7 +28,7 @@ fi
# Series name: from arg, env var, or default to today's date (e.g., jan11)
SERIES_NAME="${1:-${SERIES_NAME:-$(date +%b%d | tr '[:upper:]' '[:lower:]')}}"
# Depths to train (the "miniseries")
DEPTHS=(10 11 12 13 14 15 16 17 18 19 20)
DEPTHS=(12 14 16 18 20 22 24 26)
# Hardware
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
# Logging
@ -57,8 +57,15 @@ for d in "${DEPTHS[@]}"; do
TAG="${SERIES_NAME}_miniseries_d${d}"
START_TIME=$(date +%s)
# Train the model with natural horizon (target_param_data_ratio default)
# No --target-flops, let it use the default ratio from base_train
# Reduce --device-batch-size to avoid OOM at larger depths
if [ $d -ge 28 ]; then
DEVICE_BATCH_SIZE_ARG="--device-batch-size=8"
elif [ $d -ge 20 ]; then
DEVICE_BATCH_SIZE_ARG="--device-batch-size=16"
else
DEVICE_BATCH_SIZE_ARG="--device-batch-size=32"
fi
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
--depth=$d \
--run="${WANDB_RUN}_d${d}" \
@ -67,6 +74,7 @@ for d in "${DEPTHS[@]}"; do
--core-metric-max-per-task=-1 \
--sample-every=-1 \
--save-every=-1 \
$DEVICE_BATCH_SIZE_ARG \
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
END_TIME=$(date +%s)