Fix for garbage collection

This commit is contained in:
Chris McCormick 2026-01-31 00:33:16 -08:00
parent 35174d1725
commit 814475af42

View File

@ -144,66 +144,92 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
row_capacity = T + 1
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
bos_token = tokenizer.get_bos_token_id()
doc_buffer = []
pq_idx, rg_idx, epoch = 0, 0, 1
# Token pool: single tensor holding all buffered tokens
# Documents tracked as (start, length) tuples
pool = torch.empty(buffer_size * 512, dtype=torch.long)
pool_end = 0
docs = [] # [(start, length), ...]
def compact_pool():
"""Shift active documents to front of pool, reclaiming space."""
nonlocal pool_end
if not docs:
pool_end = 0
return
write_pos = 0
for i, (start, length) in enumerate(docs):
if start != write_pos:
pool[write_pos:write_pos + length] = pool[start:start + length].clone()
docs[i] = (write_pos, length)
write_pos += length
pool_end = write_pos
def refill_buffer():
nonlocal pq_idx, rg_idx, epoch
"""Retrieve more docs and add them to the pool"""
nonlocal pq_idx, rg_idx, epoch, pool, pool_end
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
# Number of new tokens to store
total_new = sum(len(t) for t in token_lists)
# If there's not enough space at the end,
if pool_end + total_new > pool.size(0):
compact_pool() # Try compacting first.
# If still not enough,
if pool_end + total_new > pool.size(0):
# Allocate a new, larger pool.
new_size = max(pool.size(0) * 2, pool_end + total_new)
new_pool = torch.empty(new_size, dtype=torch.long)
new_pool[:pool_end] = pool[:pool_end]
pool = new_pool
# Write tokens to pool
for tokens in token_lists:
doc_buffer.append(tokens)
n = len(tokens)
pool[pool_end:pool_end + n] = torch.tensor(tokens, dtype=torch.long)
docs.append((pool_end, n))
pool_end += n
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
# This gives us contiguous views and a single HtoD transfer
# Pre-allocate buffers once
use_cuda = device == "cuda"
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
cpu_targets = cpu_buffer[B * T:].view(B, T)
inputs = gpu_buffer[:B * T].view(B, T)
targets = gpu_buffer[B * T:].view(B, T)
row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
inputs = torch.empty((B, T), dtype=torch.long, device=device)
targets = torch.empty((B, T), dtype=torch.long, device=device)
while True:
rows = []
for _ in range(B):
row = []
while len(row) < row_capacity:
for row_idx in range(B):
col = 0
while col < row_capacity:
# Ensure buffer has documents
while len(doc_buffer) < buffer_size:
while len(docs) < buffer_size:
refill_buffer()
remaining = row_capacity - len(row)
remaining = row_capacity - col
# Find largest doc that fits entirely
best_idx = -1
best_len = 0
for i, doc in enumerate(doc_buffer):
doc_len = len(doc)
if doc_len <= remaining and doc_len > best_len:
for i, (start, length) in enumerate(docs):
if length <= remaining and length > best_len:
best_idx = i
best_len = doc_len
best_len = length
if best_idx >= 0:
doc = doc_buffer.pop(best_idx)
row.extend(doc)
start, length = docs.pop(best_idx)
row_buffer[row_idx, col:col + length] = pool[start:start + length]
col += length
else:
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
doc = doc_buffer.pop(shortest_idx)
row.extend(doc[:remaining])
# No doc fits - crop shortest to fill remaining
shortest_idx = min(range(len(docs)), key=lambda i: docs[i][1])
start, length = docs.pop(shortest_idx)
row_buffer[row_idx, col:col + remaining] = pool[start:start + remaining]
col += remaining
rows.append(row[:row_capacity])
# Convert rows to tensor and copy slices to pinned buffer (CPU work)
row_data = torch.tensor(rows, dtype=torch.long) # [B, T+1], temporary
cpu_inputs.copy_(row_data[:, :-1])
cpu_targets.copy_(row_data[:, 1:])
# Copy to GPU
inputs.copy_(row_buffer[:, :-1], non_blocking=use_cuda)
targets.copy_(row_buffer[:, 1:], non_blocking=use_cuda)
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
# Single HtoD copy into persistent GPU buffer and yield
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
yield inputs, targets, state_dict
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):