mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-02 14:00:45 +00:00
Compare commits
4 Commits
2134348a5c
...
9e95480bbd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e95480bbd | ||
|
|
1ec0a34779 | ||
|
|
ff46300720 | ||
|
|
005daea668 |
|
|
@ -104,8 +104,9 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
|||
nonlocal pq_idx, rg_idx, epoch
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
# Pre-convert to tensors once during buffering to avoid repeated torch.tensor() in inner loop
|
||||
for tokens in token_lists:
|
||||
doc_buffer.append(tokens)
|
||||
doc_buffer.append(torch.tensor(tokens, dtype=torch.long))
|
||||
|
||||
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
|
||||
# This gives us contiguous views and a single HtoD transfer
|
||||
|
|
@ -128,25 +129,25 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
|||
|
||||
remaining = row_capacity - pos
|
||||
|
||||
# Find largest doc that fits entirely
|
||||
# Find largest doc that fits entirely (doc is now a tensor)
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, doc in enumerate(doc_buffer):
|
||||
doc_len = len(doc)
|
||||
doc_len = doc.size(0)
|
||||
if doc_len <= remaining and doc_len > best_len:
|
||||
best_idx = i
|
||||
best_len = doc_len
|
||||
|
||||
if best_idx >= 0:
|
||||
doc = doc_buffer.pop(best_idx)
|
||||
doc_len = len(doc)
|
||||
row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
|
||||
doc_len = doc.size(0)
|
||||
row_buffer[row_idx, pos:pos + doc_len] = doc # Direct tensor copy, no conversion
|
||||
pos += doc_len
|
||||
else:
|
||||
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
|
||||
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
|
||||
shortest_idx = min(range(len(doc_buffer)), key=lambda i: doc_buffer[i].size(0))
|
||||
doc = doc_buffer.pop(shortest_idx)
|
||||
row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
|
||||
row_buffer[row_idx, pos:pos + remaining] = doc[:remaining] # Tensor slice, no conversion
|
||||
pos += remaining
|
||||
|
||||
# Copy to pinned CPU buffer, then single HtoD transfer
|
||||
|
|
|
|||
|
|
@ -58,6 +58,39 @@ def _use_fa3():
|
|||
# =============================================================================
|
||||
# SDPA helpers
|
||||
# =============================================================================
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=32)
def _get_sliding_window_mask(Tq: int, Tk: int, window: int, device_index: int) -> torch.Tensor:
    """
    Create and cache a causal (optionally sliding-window) attention mask.

    Query row i is treated as sitting at absolute position (Tk - Tq) + i, i.e. the
    queries are aligned to the end of the keys. An entry is True where the key
    position is <= the query position and, when a window is active, at most
    `window` positions behind it.

    Args:
        Tq: Query sequence length
        Tk: Key sequence length
        window: Sliding window size. Only applied when 0 <= window < Tk, so -1
            (or any window >= Tk) leaves the plain causal mask.
        device_index: -1 to build the mask on CPU, otherwise the CUDA device id.
            Note that 0 selects cuda:0, not CPU.

    Returns:
        Boolean mask tensor of shape (Tq, Tk)

    Note:
        Results are memoized by lru_cache, so repeated calls with the same
        arguments return the very same tensor object. Treat it as read-only:
        an in-place modification would be seen by every later caller.
    """
    # All arguments are plain ints so that they can serve as the lru_cache key;
    # the device is reconstructed from its index here.
    device = torch.device("cpu") if device_index == -1 else torch.device(f"cuda:{device_index}")

    # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
    row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
    col_idx = torch.arange(Tk, device=device).unsqueeze(0)
    mask = col_idx <= row_idx

    # sliding window (left): drop keys more than `window` positions behind the query
    if 0 <= window < Tk:
        mask = mask & ((row_idx - col_idx) <= window)

    return mask
|
||||
|
||||
|
||||
def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
||||
"""
|
||||
SDPA attention with sliding window support.
|
||||
|
|
@ -80,16 +113,10 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
|||
v = v[:, :, start:, :]
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
|
||||
# Need explicit mask for sliding window/chunk inference
|
||||
# Need explicit mask for sliding window/chunk inference - use cached mask
|
||||
device = q.device
|
||||
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
|
||||
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
|
||||
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
||||
mask = col_idx <= row_idx
|
||||
|
||||
# sliding window (left)
|
||||
if window >= 0 and window < Tk:
|
||||
mask = mask & ((row_idx - col_idx) <= window)
|
||||
device_index = device.index if device.type == "cuda" else -1
|
||||
mask = _get_sliding_window_mask(Tq, Tk, window, device_index)
|
||||
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ fi
|
|||
# Series name: from arg, env var, or default to today's date (e.g., jan11)
|
||||
SERIES_NAME="${1:-${SERIES_NAME:-$(date +%b%d | tr '[:upper:]' '[:lower:]')}}"
|
||||
# Depths to train (the "miniseries")
|
||||
DEPTHS=(10 11 12 13 14 15 16 17 18 19 20)
|
||||
DEPTHS=(12 14 16 18 20 22 24 26)
|
||||
# Hardware
|
||||
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
||||
# Logging
|
||||
|
|
@ -57,8 +57,15 @@ for d in "${DEPTHS[@]}"; do
|
|||
TAG="${SERIES_NAME}_miniseries_d${d}"
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
# Train the model with natural horizon (target_param_data_ratio default)
|
||||
# No --target-flops, let it use the default ratio from base_train
|
||||
# Reduce --device-batch-size to avoid OOM at larger depths
|
||||
if [ $d -ge 28 ]; then
|
||||
DEVICE_BATCH_SIZE_ARG="--device-batch-size=8"
|
||||
elif [ $d -ge 20 ]; then
|
||||
DEVICE_BATCH_SIZE_ARG="--device-batch-size=16"
|
||||
else
|
||||
DEVICE_BATCH_SIZE_ARG="--device-batch-size=32"
|
||||
fi
|
||||
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
|
||||
--depth=$d \
|
||||
--run="${WANDB_RUN}_d${d}" \
|
||||
|
|
@ -67,6 +74,7 @@ for d in "${DEPTHS[@]}"; do
|
|||
--core-metric-max-per-task=-1 \
|
||||
--sample-every=-1 \
|
||||
--save-every=-1 \
|
||||
$DEVICE_BATCH_SIZE_ARG \
|
||||
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user