From e8fec97d4c6554b0c898a6c5c747a0496fe9b761 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Mon, 2 Feb 2026 01:17:30 +0000
Subject: [PATCH] slightly more efficient dataloader that reduces the number
 of python objects flying around and causing strain on runtime and garbage
 collector

---
 nanochat/dataloader.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py
index 1cbdef7..125625f 100644
--- a/nanochat/dataloader.py
+++ b/nanochat/dataloader.py
@@ -110,6 +110,7 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
     # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
     # This gives us contiguous views and a single HtoD transfer
     use_cuda = device == "cuda"
+    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)  # for building rows without creating Python lists
     cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda)  # staging area (CPU)
     gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device)  # on-device buffer
     cpu_inputs = cpu_buffer[:B * T].view(B, T)  # a few views into these buffers just for convenience
@@ -118,15 +119,14 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
     targets = gpu_buffer[B * T:].view(B, T)
 
     while True:
-        rows = []
-        for _ in range(B):
-            row = []
-            while len(row) < row_capacity:
+        for row_idx in range(B):
+            pos = 0
+            while pos < row_capacity:
                 # Ensure buffer has documents
                 while len(doc_buffer) < buffer_size:
                     refill_buffer()
 
-                remaining = row_capacity - len(row)
+                remaining = row_capacity - pos
 
                 # Find largest doc that fits entirely
                 best_idx = -1
@@ -139,19 +139,19 @@
                 if best_idx >= 0:
                     doc = doc_buffer.pop(best_idx)
-                    row.extend(doc)
+                    doc_len = len(doc)
+                    row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
+                    pos += doc_len
                 else:
                     # No doc fits - crop shortest in buffer to fill remaining and minimize waste
                     shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                     doc = doc_buffer.pop(shortest_idx)
-                    row.extend(doc[:remaining])
+                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
+                    pos += remaining
 
-            rows.append(row[:row_capacity])
-
-        # Convert rows to tensor and copy slices to pinned buffer (CPU work)
-        row_data = torch.tensor(rows, dtype=torch.long)  # [B, T+1], temporary
-        cpu_inputs.copy_(row_data[:, :-1])
-        cpu_targets.copy_(row_data[:, 1:])
+        # Copy to pinned CPU buffer, then single HtoD transfer
+        cpu_inputs.copy_(row_buffer[:, :-1])
+        cpu_targets.copy_(row_buffer[:, 1:])
 
         state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
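
---

Reviewer note, not part of the patch: below is a minimal, self-contained
micro-benchmark sketch of the idea in this commit -- building each batch row
directly inside a preallocated torch.long buffer instead of growing Python
lists and materializing a fresh [B, T+1] tensor per batch. The B/T values,
the fake_docs list, and both helper function names are illustrative
assumptions, not code from nanochat/dataloader.py.

    import time
    import torch

    B, T = 32, 2048          # assumed batch size / sequence length, for illustration
    row_capacity = T + 1     # each row holds T inputs plus one extra token for the shifted targets
    fake_docs = [list(range(100)) for _ in range(256)]  # stand-in for tokenized documents

    def rows_via_lists():
        # old approach: grow per-row Python lists, then one big temporary tensor
        rows, di = [], 0
        for _ in range(B):
            row = []
            while len(row) < row_capacity:
                row.extend(fake_docs[di % len(fake_docs)])
                di += 1
            rows.append(row[:row_capacity])
        return torch.tensor(rows, dtype=torch.long)

    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)  # preallocated once

    def rows_via_buffer():
        # new approach: write each document straight into the preallocated buffer
        di = 0
        for r in range(B):
            pos = 0
            while pos < row_capacity:
                doc = fake_docs[di % len(fake_docs)]
                di += 1
                n = min(len(doc), row_capacity - pos)  # crop if the doc overruns the row
                row_buffer[r, pos:pos + n] = torch.tensor(doc[:n], dtype=torch.long)
                pos += n
        return row_buffer

    for fn in (rows_via_lists, rows_via_buffer):
        t0 = time.perf_counter()
        for _ in range(20):
            fn()
        dt = (time.perf_counter() - t0) / 20
        print(f"{fn.__name__}: {dt * 1e3:.2f} ms/iter")

Note that the buffer path still allocates one small temporary tensor per
document; the win the commit message claims appears to come from dropping the
per-row Python lists (roughly B * (T+1) object references per batch) and the
large per-batch [B, T+1] temporary, which is what stressed the garbage
collector.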