# nanochat/nanochat/dataloader.py
# 167 lines, 7.3 KiB, Python

"""
Distributed dataloaders for pretraining.
BOS-aligned bestfit:
- Every row starts with BOS token
- Documents packed using best-fit algorithm to minimize cropping
- When no document fits remaining space, crops a document to fill exactly
- 100% utilization (no padding), ~35% tokens cropped at T=2048
Compared to the original tokenizing_distributed_data_loader:
BOS-aligned loses ~35% of tokens to cropping, but ensures that
there are fewer "confusing" tokens in the train/val batches as every token can
now attend back to the BOS token and sees the full context of the document.
Fall back to the original loader if you have very limited data AND long documents:
https://github.com/karpathy/nanochat/blob/3c3a3d7/nanochat/dataloader.py#L78-L117
"""
import torch
import pyarrow.parquet as pq
from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
    """
    Infinite iterator over document batches (lists of text strings) from parquet files.

    Handles DDP sharding and approximate resume. Each yield is
    (text_batch, (pq_idx, rg_idx, epoch)) where text_batch is a list of document
    strings, pq_idx / rg_idx are the parquet file index and row group index the
    batch came from (used for resumption), and epoch counts how many times we've
    cycled through the dataset (starts at 1).

    Args:
        split: "train" reads every parquet file except the last one; any other
            value reads only the last file.
        resume_state_dict: None to start from the beginning, or a dict with keys
            "pq_idx" and "rg_idx" and optionally "epoch" (defaults to 1).
        tokenizer_batch_size: maximum number of documents per yielded batch.

    Raises:
        AssertionError: if no parquet files exist, or none are left for the
            requested split after the train/val partition.
        RuntimeError: if a complete pass over the split (one that started at the
            first file) produced no documents for this rank. Without this check
            the generator would spin forever without yielding.
    """
    _, ddp_rank, _, ddp_world_size = get_dist_info()

    warn_on_legacy = ddp_rank == 0 and split == "train"  # rank 0 on train split will warn on legacy
    parquet_paths = list_parquet_files(warn_on_legacy=warn_on_legacy)
    assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
    # The last file is held out for val, so a single-file dataset leaves train empty.
    assert len(parquet_paths) != 0, (
        f"No parquet files left for split '{split}': the last file is reserved for val, "
        "so the train split needs at least two files"
    )

    resuming = resume_state_dict is not None
    resume_pq_idx = resume_state_dict["pq_idx"] if resuming else 0
    resume_rg_idx = resume_state_dict["rg_idx"] if resuming else None
    epoch = resume_state_dict.get("epoch", 1) if resuming else 1
    first_pass = True

    while True:  # iterate infinitely (multi-epoch)
        pq_idx = resume_pq_idx if first_pass else 0
        # A resumed first pass may legitimately be empty (the resume point can sit at the
        # very end of the data); only a pass that started from scratch must yield something.
        full_pass = (not first_pass) or (resume_pq_idx == 0 and resume_rg_idx is None)
        yielded_any = False
        while pq_idx < len(parquet_paths):
            filepath = parquet_paths[pq_idx]
            pf = pq.ParquetFile(filepath)
            # Start from resume point if resuming on same file, otherwise from DDP rank
            if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
                base_idx = resume_rg_idx // ddp_world_size
                base_idx += 1  # advance by 1 so we don't repeat data after resuming
                rg_idx = base_idx * ddp_world_size + ddp_rank
                if rg_idx >= pf.num_row_groups:
                    # Resume point was the tail of this file: move on to the next one
                    pq_idx += 1
                    continue
                resume_rg_idx = None  # only do this once
            else:
                rg_idx = ddp_rank
            # Each rank reads row groups rank, rank + world_size, rank + 2 * world_size, ...
            while rg_idx < pf.num_row_groups:
                rg = pf.read_row_group(rg_idx)
                batch = rg.column('text').to_pylist()
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
                    yielded_any = True
                rg_idx += ddp_world_size
            pq_idx += 1
        if full_pass and not yielded_any:
            raise RuntimeError(
                f"Rank {ddp_rank} (world size {ddp_world_size}) found no documents in split "
                f"'{split}': the parquet files are empty or have too few row groups for this rank"
            )
        first_pass = False
        epoch += 1
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
    tokenizer, B, T, split,
    tokenizer_threads=4, tokenizer_batch_size=128,
    device="cuda", resume_state_dict=None,
    buffer_size=1000
):
    """
    BOS-aligned dataloader with Best-Fit Cropping.

    Reduces token waste compared to simple greedy cropping by searching a buffer
    for documents that fit well, while maintaining 100% utilization (no padding).

    Algorithm for each row:
    1. From buffered docs, pick the LARGEST doc that fits entirely
    2. Repeat until no doc fits
    3. When nothing fits, crop a doc to fill remaining space exactly

    Key properties:
    - Every row starts with BOS
    - 100% utilization (no padding, every token is trained on)
    - Approximately 35% of all tokens are discarded due to cropping

    Args:
        tokenizer: object providing get_bos_token_id() and
            encode(texts, prepend=..., num_threads=...) returning one token
            sequence per text.
        B: number of rows per batch.
        T: number of tokens per row in the yielded inputs / targets.
        split: "train" or "val".
        tokenizer_threads: forwarded to tokenizer.encode as num_threads.
        tokenizer_batch_size: number of documents tokenized per refill.
        device: where the yielded tensors live.
        resume_state_dict: None, or a state_dict previously yielded by this loader.
        buffer_size: target number of tokenized documents kept buffered for the
            best-fit search. Values below 1 are treated as 1.

    Yields:
        (inputs, targets, state_dict). inputs and targets are (B, T) views into a
        single buffer on `device` that is allocated once and overwritten on every
        iteration, so use (or clone) them before advancing the generator.
        state_dict has keys "pq_idx", "rg_idx" and "epoch" describing the most
        recently read document batch.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"

    row_capacity = T + 1  # one extra token so inputs and targets are shifted slices of the same row
    batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    pq_idx, rg_idx, epoch = 0, 0, 1
    # Always keep at least one document buffered: with buffer_size <= 0 the buffer would
    # never be refilled and the crop fallback below would call min() on an empty range.
    min_buffered = max(buffer_size, 1)

    def refill_buffer():
        """Tokenize the next document batch into doc_buffer and record its position."""
        nonlocal pq_idx, rg_idx, epoch
        doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
        # Skip empty token lists: they can never be picked by best-fit and would make the
        # crop assignment fail with a shape mismatch.
        doc_buffer.extend(tokens for tokens in token_lists if len(tokens) > 0)

    # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
    # This gives us contiguous views and a single HtoD transfer
    # NOTE(review): this is a plain equality test against the string "cuda"; other spellings
    # such as "cuda:0" do not match and fall back to unpinned, blocking copies.
    # Confirm what callers actually pass before changing it.
    use_cuda = device == "cuda"
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)  # for building rows without creating Python lists
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda)  # staging area (CPU)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device)  # on-device buffer
    cpu_inputs = cpu_buffer[:B * T].view(B, T)  # a few views into these buffers just for convenience
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)

    while True:
        for row_idx in range(B):
            pos = 0
            while pos < row_capacity:
                # Ensure buffer has documents
                while len(doc_buffer) < min_buffered:
                    refill_buffer()

                remaining = row_capacity - pos

                # Find largest doc that fits entirely (first one wins on ties)
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    doc_len = len(doc)
                    if best_len < doc_len <= remaining:
                        best_idx = i
                        best_len = doc_len

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row_buffer[row_idx, pos:pos + best_len] = torch.tensor(doc, dtype=torch.long)
                    pos += best_len
                else:
                    # No doc fits, so every buffered doc is longer than `remaining`:
                    # crop the shortest one to fill the row exactly and minimize waste
                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(shortest_idx)
                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
                    pos += remaining

        # Copy to the CPU staging buffer: inputs drop the last token, targets drop the first
        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])

        state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}

        # Single HtoD copy into persistent GPU buffer and yield
        # NOTE(review): with non_blocking=True the staging buffer is rewritten on the next
        # iteration; presumably the consumer synchronizes before then - TODO confirm.
        gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
        yield inputs, targets, state_dict
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
    """Same as the stateful BOS best-fit loader, but yields only (inputs, targets)."""
    stateful_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs)
    for triple in stateful_loader:
        # triple is (inputs, targets, state_dict); the resume state is dropped here
        yield triple[0], triple[1]