slightly more efficient dataloader that reduces the number of Python objects flying around and causing strain on the runtime and garbage collector

Andrej Karpathy 2026-02-02 01:17:30 +00:00
parent 8b4849d548
commit e8fec97d4c


@@ -110,6 +110,7 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
     # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
     # This gives us contiguous views and a single HtoD transfer
     use_cuda = device == "cuda"
+    row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists
     cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
     gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
     cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
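
For context, a minimal sketch of the buffer layout this hunk builds on (the B, T, and device values below are illustrative, not from the source). The CPU staging tensor is laid out as [inputs (B*T) | targets (B*T)], the views alias it without copying, and one copy_ call moves both halves to the device; pinning the CPU buffer is what allows that HtoD copy to run asynchronously on CUDA.

import torch

# illustrative sizes; the real loader gets B, T, row_capacity from its arguments
B, T = 4, 8
row_capacity = T + 1  # each packed row carries T+1 tokens -> T (input, target) pairs
device = "cpu"        # "cuda" in the real loader
use_cuda = device == "cuda"

row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda)  # staging (CPU)
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device)        # on-device
cpu_inputs = cpu_buffer[:B * T].view(B, T)   # views alias cpu_buffer, no copies made
cpu_targets = cpu_buffer[B * T:].view(B, T)
inputs = gpu_buffer[:B * T].view(B, T)
targets = gpu_buffer[B * T:].view(B, T)

cpu_inputs.fill_(1)
cpu_targets.fill_(2)
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)  # the single HtoD transfer
assert inputs.unique().item() == 1 and targets.unique().item() == 2
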
@@ -118,15 +119,14 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
     targets = gpu_buffer[B * T:].view(B, T)
     while True:
-        rows = []
-        for _ in range(B):
-            row = []
-            while len(row) < row_capacity:
+        for row_idx in range(B):
+            pos = 0
+            while pos < row_capacity:
                 # Ensure buffer has documents
                 while len(doc_buffer) < buffer_size:
                     refill_buffer()
-                remaining = row_capacity - len(row)
+                remaining = row_capacity - pos
                 # Find largest doc that fits entirely
                 best_idx = -1
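
The scan that sets best_idx sits between these hunks and is unchanged by the commit; a hedged sketch of the rule it implements (find_best_fit and the sample docs below are hypothetical stand-ins, not the loader's code): pick the longest buffered document that still fits in the remaining space, and signal -1 when nothing fits.

def find_best_fit(doc_buffer, remaining):
    # largest document with len(doc) <= remaining, or -1 if none fits
    best_idx, best_len = -1, 0
    for i, doc in enumerate(doc_buffer):
        if best_len < len(doc) <= remaining:
            best_idx, best_len = i, len(doc)
    return best_idx

docs = [[1] * 5, [2] * 3, [3] * 9]
assert find_best_fit(docs, remaining=6) == 0   # the 5-token doc is the largest that fits
assert find_best_fit(docs, remaining=2) == -1  # nothing fits -> caller crops the shortest doc
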
@@ -139,19 +139,19 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
                 if best_idx >= 0:
                     doc = doc_buffer.pop(best_idx)
-                    row.extend(doc)
+                    doc_len = len(doc)
+                    row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
+                    pos += doc_len
                 else:
                     # No doc fits - crop shortest in buffer to fill remaining and minimize waste
                     shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                     doc = doc_buffer.pop(shortest_idx)
-                    row.extend(doc[:remaining])
+                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
+                    pos += remaining
-            rows.append(row[:row_capacity])
-        # Convert rows to tensor and copy slices to pinned buffer (CPU work)
-        row_data = torch.tensor(rows, dtype=torch.long) # [B, T+1], temporary
-        cpu_inputs.copy_(row_data[:, :-1])
-        cpu_targets.copy_(row_data[:, 1:])
+        # Copy to pinned CPU buffer, then single HtoD transfer
+        cpu_inputs.copy_(row_buffer[:, :-1])
+        cpu_targets.copy_(row_buffer[:, 1:])
         state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
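
Putting the hunks together, a hedged, self-contained sketch of the new packing loop. refill_buffer, buffer_size, and the random token values are stand-ins; the real refill_buffer presumably tokenizes documents pulled from parquet row groups, judging by the pq_idx/rg_idx entries in the state dict.

import random
import torch

B, T = 2, 8
row_capacity = T + 1
buffer_size = 4
doc_buffer = []

def refill_buffer():
    # stand-in for the real tokenizing refill: append one random "document"
    doc_buffer.append([random.randint(1, 99) for _ in range(random.randint(2, 6))])

row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
for row_idx in range(B):
    pos = 0
    while pos < row_capacity:
        while len(doc_buffer) < buffer_size:
            refill_buffer()
        remaining = row_capacity - pos
        fits = [i for i, d in enumerate(doc_buffer) if len(d) <= remaining]
        if fits:
            best_idx = max(fits, key=lambda i: len(doc_buffer[i]))  # best fit
            doc = doc_buffer.pop(best_idx)
            row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
            pos += len(doc)
        else:
            # no doc fits: crop the shortest one to fill the row exactly
            shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
            doc = doc_buffer.pop(shortest_idx)
            row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
            pos += remaining

print(row_buffer)  # [B, T+1], fully packed

Tokens land directly in row_buffer at pos, so the per-row Python lists and the temporary [B, T+1] tensor the old code built on every step never exist; the two copy_ slices then stage inputs and targets into the pinned buffer exactly as before.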