Mirror of https://github.com/karpathy/nanochat.git (synced 2026-03-02 23:40:36 +00:00)
slightly more efficient dataloader that reduces the number of Python objects flying around and causing strain on the runtime and garbage collector
parent 8b4849d548
commit e8fec97d4c
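The gist of the change: each batch used to be assembled as B growing Python lists of token ints and then converted with torch.tensor(rows), churning through thousands of short-lived Python objects per step; the new code writes every document directly into a pre-allocated row_buffer tensor. Below is a minimal runnable sketch of the before/after pattern, with toy values for B, row_capacity, and docs standing in for the real streaming loader state:

    import torch

    B, row_capacity = 4, 9
    docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] * 8  # toy token lists

    # Before: build nested Python lists, then materialize a fresh tensor every
    # batch -- lots of short-lived objects for the runtime and GC to manage.
    rows = []
    for _ in range(B):
        row = []
        while len(row) < row_capacity:
            row.extend(docs.pop())
        rows.append(row[:row_capacity])
    row_data = torch.tensor(rows, dtype=torch.long)  # temporary [B, row_capacity]

    # After: write tokens in place into one long-lived tensor, no nested lists.
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    for row_idx in range(B):
        pos = 0
        while pos < row_capacity:
            doc = docs.pop()[:row_capacity - pos]  # crop if it overshoots the row
            row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
            pos += len(doc)

Note that the per-document torch.tensor(doc, ...) conversions remain in the new version; what disappears is the B growing row lists and the per-batch nested-list-to-tensor conversion.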
@@ -110,6 +110,7 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
     # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
     # This gives us contiguous views and a single HtoD transfer
     use_cuda = device == "cuda"
+    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)  # for building rows without creating Python lists
     cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda)  # staging area (CPU)
     gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device)  # on-device buffer
     cpu_inputs = cpu_buffer[:B * T].view(B, T)  # a few views into these buffers just for convenience
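The new row_buffer joins the existing one-time allocations: inputs and targets share a single flat CPU staging buffer (pinned when running on CUDA) and a matching GPU buffer, so each batch needs only one host-to-device copy. A hedged sketch of that layout and transfer; the final non_blocking copy is my illustration of the "single HtoD transfer" the comment refers to, since that line sits outside this hunk:

    import torch

    B, T = 4, 8
    row_capacity = T + 1  # rows carry T+1 tokens; inputs/targets overlap by one
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"

    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device)
    cpu_inputs = cpu_buffer[:B * T].view(B, T)
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)

    row_buffer.random_(0, 100)            # stand-in for one packed batch
    cpu_inputs.copy_(row_buffer[:, :-1])  # next-token setup: targets are the
    cpu_targets.copy_(row_buffer[:, 1:])  # inputs shifted left by one position
    gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)  # one copy moves both halves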
@@ -118,15 +119,14 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
     targets = gpu_buffer[B * T:].view(B, T)

     while True:
-        rows = []
-        for _ in range(B):
-            row = []
-            while len(row) < row_capacity:
+        for row_idx in range(B):
+            pos = 0
+            while pos < row_capacity:
                 # Ensure buffer has documents
                 while len(doc_buffer) < buffer_size:
                     refill_buffer()

-                remaining = row_capacity - len(row)
+                remaining = row_capacity - pos

                 # Find largest doc that fits entirely
                 best_idx = -1
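In the rewritten loop, the cursor pos counts tokens already written into row_buffer[row_idx], exactly what len(row) measured before, so remaining keeps its meaning. The scan that sets best_idx lies between this hunk and the next as unchanged context and is not shown in the diff; the sketch below is an assumption about its shape (a linear pass keeping the longest buffered document that still fits), not the repository's actual code:

    doc_buffer = [[1, 2], [3, 4, 5, 6], [7, 8, 9]]  # toy stand-in for the buffer
    remaining = 3

    # Hypothetical best-fit scan: largest document with len(doc) <= remaining.
    best_idx = -1
    best_len = 0
    for i, d in enumerate(doc_buffer):
        if best_len < len(d) <= remaining:
            best_idx, best_len = i, len(d)

    print(best_idx)  # 2: [7, 8, 9] is the largest doc that fits entirely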
@@ -139,19 +139,19 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
                 if best_idx >= 0:
                     doc = doc_buffer.pop(best_idx)
-                    row.extend(doc)
+                    doc_len = len(doc)
+                    row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
+                    pos += doc_len
                 else:
                     # No doc fits - crop shortest in buffer to fill remaining and minimize waste
                     shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                     doc = doc_buffer.pop(shortest_idx)
-                    row.extend(doc[:remaining])
+                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
+                    pos += remaining

-            rows.append(row[:row_capacity])
-
-        # Convert rows to tensor and copy slices to pinned buffer (CPU work)
-        row_data = torch.tensor(rows, dtype=torch.long)  # [B, T+1], temporary
-        cpu_inputs.copy_(row_data[:, :-1])
-        cpu_targets.copy_(row_data[:, 1:])
+        # Copy to pinned CPU buffer, then single HtoD transfer
+        cpu_inputs.copy_(row_buffer[:, :-1])
+        cpu_targets.copy_(row_buffer[:, 1:])

         state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}


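Both branches now advance pos, so the while pos < row_capacity loop still terminates for each row. The fallback branch crops the shortest buffered document because, among candidates that are all too long, the shortest one loses the fewest tokens to the crop. A toy illustration with made-up lengths:

    doc_buffer = [[1] * 7, [2] * 12, [3] * 5]  # toy docs, all longer than remaining
    remaining = 4

    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
    doc = doc_buffer.pop(shortest_idx)
    kept, wasted = doc[:remaining], len(doc) - remaining
    print(len(kept), wasted)  # 4 1 -- fills the row exactly, wasting one token

The removal of torch.tensor(rows) is the heart of the commit message: it rebuilt a [B, T+1] tensor from nested Python lists on every iteration, while the new cpu_inputs.copy_(row_buffer[:, :-1]) and cpu_targets.copy_(row_buffer[:, 1:]) calls reuse the same pinned staging memory batch after batch.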