This commit is contained in:
Chris McCormick 2026-01-31 13:25:42 -05:00 committed by GitHub
commit b273607399
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 77 additions and 41 deletions

View File

@ -144,66 +144,92 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
row_capacity = T + 1
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
bos_token = tokenizer.get_bos_token_id()
doc_buffer = []
pq_idx, rg_idx, epoch = 0, 0, 1
# Token pool: single tensor holding all buffered tokens
# Documents tracked as (start, length) tuples
pool = torch.empty(buffer_size * 512, dtype=torch.long)
pool_end = 0
docs = [] # [(start, length), ...]
def compact_pool():
"""Shift active documents to front of pool, reclaiming space."""
nonlocal pool_end
if not docs:
pool_end = 0
return
write_pos = 0
for i, (start, length) in enumerate(docs):
if start != write_pos:
pool[write_pos:write_pos + length] = pool[start:start + length].clone()
docs[i] = (write_pos, length)
write_pos += length
pool_end = write_pos
def refill_buffer():
nonlocal pq_idx, rg_idx, epoch
"""Retrieve more docs and add them to the pool"""
nonlocal pq_idx, rg_idx, epoch, pool, pool_end
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
# Number of new tokens to store
total_new = sum(len(t) for t in token_lists)
# If there's not enough space at the end,
if pool_end + total_new > pool.size(0):
compact_pool() # Try compacting first.
# If still not enough,
if pool_end + total_new > pool.size(0):
# Allocate a new, larger pool.
new_size = max(pool.size(0) * 2, pool_end + total_new)
new_pool = torch.empty(new_size, dtype=torch.long)
new_pool[:pool_end] = pool[:pool_end]
pool = new_pool
# Write tokens to pool
for tokens in token_lists:
doc_buffer.append(tokens)
n = len(tokens)
pool[pool_end:pool_end + n] = torch.tensor(tokens, dtype=torch.long)
docs.append((pool_end, n))
pool_end += n
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
# This gives us contiguous views and a single HtoD transfer
# Pre-allocate buffers once
use_cuda = device == "cuda"
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
cpu_targets = cpu_buffer[B * T:].view(B, T)
inputs = gpu_buffer[:B * T].view(B, T)
targets = gpu_buffer[B * T:].view(B, T)
row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
inputs = torch.empty((B, T), dtype=torch.long, device=device)
targets = torch.empty((B, T), dtype=torch.long, device=device)
while True:
rows = []
for _ in range(B):
row = []
while len(row) < row_capacity:
for row_idx in range(B):
col = 0
while col < row_capacity:
# Ensure buffer has documents
while len(doc_buffer) < buffer_size:
while len(docs) < buffer_size:
refill_buffer()
remaining = row_capacity - len(row)
remaining = row_capacity - col
# Find largest doc that fits entirely
best_idx = -1
best_len = 0
for i, doc in enumerate(doc_buffer):
doc_len = len(doc)
if doc_len <= remaining and doc_len > best_len:
for i, (start, length) in enumerate(docs):
if length <= remaining and length > best_len:
best_idx = i
best_len = doc_len
best_len = length
if best_idx >= 0:
doc = doc_buffer.pop(best_idx)
row.extend(doc)
start, length = docs.pop(best_idx)
row_buffer[row_idx, col:col + length] = pool[start:start + length]
col += length
else:
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
doc = doc_buffer.pop(shortest_idx)
row.extend(doc[:remaining])
# No doc fits - crop shortest to fill remaining
shortest_idx = min(range(len(docs)), key=lambda i: docs[i][1])
start, length = docs.pop(shortest_idx)
row_buffer[row_idx, col:col + remaining] = pool[start:start + remaining]
col += remaining
rows.append(row[:row_capacity])
# Convert rows to tensor and copy slices to pinned buffer (CPU work)
row_data = torch.tensor(rows, dtype=torch.long) # [B, T+1], temporary
cpu_inputs.copy_(row_data[:, :-1])
cpu_targets.copy_(row_data[:, 1:])
# Copy to GPU
inputs.copy_(row_buffer[:, :-1], non_blocking=use_cuda)
targets.copy_(row_buffer[:, 1:], non_blocking=use_cuda)
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
# Single HtoD copy into persistent GPU buffer and yield
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
yield inputs, targets, state_dict
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):

View File

@ -217,7 +217,7 @@ class MuonAdamW(torch.optim.Optimizer):
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
adamw_step_fused(
p, grad, exp_avg, exp_avg_sq,
p.data, grad, exp_avg, exp_avg_sq,
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)

View File

@ -4,22 +4,32 @@
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.
# 1) Example launch (simplest):
# bash speedrun.sh
# bash runs/speedrun.sh
# 2) Example launch in a screen session (because the run takes ~4 hours):
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# screen -L -Logfile speedrun.log -S speedrun bash runs/speedrun.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash runs/speedrun.sh
# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
# Quote the expansion so a $HOME containing spaces or glob characters stays one path
mkdir -p "$NANOCHAT_BASE_DIR"
# -----------------------------------------------------------------------------
# System dependencies (Python dev headers needed for Triton/torch compilation)
# dpkg -s exits non-zero when the package is not installed
if ! dpkg -s python3-dev &> /dev/null; then
    echo "Installing python3-dev (required for Python.h)..."
    sudo apt-get update && sudo apt-get install -y python3-dev
fi
# -----------------------------------------------------------------------------
# Python venv setup with uv
# install uv (if not already installed)
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# add uv to PATH (the installer puts it in ~/.local/bin)
export PATH="$HOME/.local/bin:$PATH"
# create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv
# install the repo dependencies
@ -81,7 +91,7 @@ wait $DATASET_DOWNLOAD_PID
NPROC_PER_NODE=8
# pretrain the d20 model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=24 --target-param-data-ratio=12 --device-batch-size=16 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks