mirror of https://github.com/karpathy/nanochat.git
synced 2026-04-15 21:38:24 +00:00

Fix for garbage collection

commit 814475af42
parent 35174d1725
@@ -144,66 +144,92 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
     row_capacity = T + 1
     batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
     bos_token = tokenizer.get_bos_token_id()
-    doc_buffer = []
     pq_idx, rg_idx, epoch = 0, 0, 1
 
+    # Token pool: single tensor holding all buffered tokens
+    # Documents tracked as (start, length) tuples
+    pool = torch.empty(buffer_size * 512, dtype=torch.long)
+    pool_end = 0
+    docs = [] # [(start, length), ...]
+
+    def compact_pool():
+        """Shift active documents to front of pool, reclaiming space."""
+        nonlocal pool_end
+        if not docs:
+            pool_end = 0
+            return
+        write_pos = 0
+        for i, (start, length) in enumerate(docs):
+            if start != write_pos:
+                pool[write_pos:write_pos + length] = pool[start:start + length].clone()
+            docs[i] = (write_pos, length)
+            write_pos += length
+        pool_end = write_pos
+
     def refill_buffer():
-        nonlocal pq_idx, rg_idx, epoch
+        """Retrieve more docs and add them to the pool"""
+        nonlocal pq_idx, rg_idx, epoch, pool, pool_end
         doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
         token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
+        # Number of new tokens to store
+        total_new = sum(len(t) for t in token_lists)
+        # If there's not enough space at the end,
+        if pool_end + total_new > pool.size(0):
+            compact_pool() # Try compacting first.
+            # If still not enough,
+            if pool_end + total_new > pool.size(0):
+                # Allocate a new, larger pool.
+                new_size = max(pool.size(0) * 2, pool_end + total_new)
+                new_pool = torch.empty(new_size, dtype=torch.long)
+                new_pool[:pool_end] = pool[:pool_end]
+                pool = new_pool
+        # Write tokens to pool
         for tokens in token_lists:
-            doc_buffer.append(tokens)
+            n = len(tokens)
+            pool[pool_end:pool_end + n] = torch.tensor(tokens, dtype=torch.long)
+            docs.append((pool_end, n))
+            pool_end += n
 
-    # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
-    # This gives us contiguous views and a single HtoD transfer
+    # Pre-allocate buffers once
     use_cuda = device == "cuda"
-    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
-    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
-    cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
-    cpu_targets = cpu_buffer[B * T:].view(B, T)
-    inputs = gpu_buffer[:B * T].view(B, T)
-    targets = gpu_buffer[B * T:].view(B, T)
+    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
+    inputs = torch.empty((B, T), dtype=torch.long, device=device)
+    targets = torch.empty((B, T), dtype=torch.long, device=device)
 
     while True:
-        rows = []
-        for _ in range(B):
-            row = []
-            while len(row) < row_capacity:
+        for row_idx in range(B):
+            col = 0
+            while col < row_capacity:
                 # Ensure buffer has documents
-                while len(doc_buffer) < buffer_size:
+                while len(docs) < buffer_size:
                     refill_buffer()
 
-                remaining = row_capacity - len(row)
+                remaining = row_capacity - col
 
                 # Find largest doc that fits entirely
                 best_idx = -1
                 best_len = 0
-                for i, doc in enumerate(doc_buffer):
-                    doc_len = len(doc)
-                    if doc_len <= remaining and doc_len > best_len:
+                for i, (start, length) in enumerate(docs):
+                    if length <= remaining and length > best_len:
                         best_idx = i
-                        best_len = doc_len
+                        best_len = length
 
                 if best_idx >= 0:
-                    doc = doc_buffer.pop(best_idx)
-                    row.extend(doc)
+                    start, length = docs.pop(best_idx)
+                    row_buffer[row_idx, col:col + length] = pool[start:start + length]
+                    col += length
                 else:
-                    # No doc fits - crop shortest in buffer to fill remaining and minimize waste
-                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
-                    doc = doc_buffer.pop(shortest_idx)
-                    row.extend(doc[:remaining])
+                    # No doc fits - crop shortest to fill remaining
+                    shortest_idx = min(range(len(docs)), key=lambda i: docs[i][1])
+                    start, length = docs.pop(shortest_idx)
+                    row_buffer[row_idx, col:col + remaining] = pool[start:start + remaining]
+                    col += remaining
 
-            rows.append(row[:row_capacity])
-
-        # Convert rows to tensor and copy slices to pinned buffer (CPU work)
-        row_data = torch.tensor(rows, dtype=torch.long) # [B, T+1], temporary
-        cpu_inputs.copy_(row_data[:, :-1])
-        cpu_targets.copy_(row_data[:, 1:])
+        # Copy to GPU
+        inputs.copy_(row_buffer[:, :-1], non_blocking=use_cuda)
+        targets.copy_(row_buffer[:, 1:], non_blocking=use_cuda)
 
         state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
 
-        # Single HtoD copy into persistent GPU buffer and yield
-        gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
         yield inputs, targets, state_dict
 
 
 def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
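
The diff replaces doc_buffer, a Python list holding one list of token ids per document, with a single flat tensor (pool) plus (start, length) bookkeeping in docs, and packs each batch row into a preallocated row_buffer instead of building it with list.extend and a temporary torch.tensor(rows). The sketch below distills that pool / compaction / best-fit pattern into a standalone, runnable example; it is illustrative only, and the TokenPool name, the capacity, and the toy documents are assumptions made for the example, not code from the repository.

# Illustrative sketch only (not repository code): the pool-and-spans pattern
# this commit introduces, reduced to a standalone class.
import torch

class TokenPool:
    """One flat tensor of tokens; documents tracked as (start, length) spans."""

    def __init__(self, capacity):
        self.pool = torch.empty(capacity, dtype=torch.long)
        self.pool_end = 0
        self.docs = []  # [(start, length), ...]

    def compact(self):
        # Shift surviving documents to the front, reclaiming holes left by popped docs.
        write_pos = 0
        for i, (start, length) in enumerate(self.docs):
            if start != write_pos:
                self.pool[write_pos:write_pos + length] = self.pool[start:start + length].clone()
            self.docs[i] = (write_pos, length)
            write_pos += length
        self.pool_end = write_pos

    def add(self, tokens):
        n = len(tokens)
        if self.pool_end + n > self.pool.size(0):
            self.compact()  # try to reclaim space first
            if self.pool_end + n > self.pool.size(0):
                # still not enough room: grow the backing tensor (amortized doubling)
                new_size = max(self.pool.size(0) * 2, self.pool_end + n)
                new_pool = torch.empty(new_size, dtype=torch.long)
                new_pool[:self.pool_end] = self.pool[:self.pool_end]
                self.pool = new_pool
        self.pool[self.pool_end:self.pool_end + n] = torch.tensor(tokens, dtype=torch.long)
        self.docs.append((self.pool_end, n))
        self.pool_end += n

    def pop_best_fit(self, remaining):
        # Largest document that fits entirely into `remaining` slots; None if nothing fits.
        best_idx, best_len = -1, 0
        for i, (_, length) in enumerate(self.docs):
            if best_len < length <= remaining:
                best_idx, best_len = i, length
        if best_idx < 0:
            return None
        start, length = self.docs.pop(best_idx)
        return self.pool[start:start + length]

if __name__ == "__main__":
    pool = TokenPool(capacity=8)  # deliberately tiny so the pool has to grow once
    for doc in ([1, 2, 3], [4, 5], [6, 7, 8, 9], [10], [11, 12, 13, 14, 15]):
        pool.add(doc)
    row, row_capacity = [], 8  # pack one row of T + 1 = 8 tokens, best-fit first
    while sum(len(r) for r in row) < row_capacity:
        remaining = row_capacity - sum(len(r) for r in row)
        piece = pool.pop_best_fit(remaining)
        if piece is None:  # nothing fits: crop the shortest doc, as the loader does
            i = min(range(len(pool.docs)), key=lambda j: pool.docs[j][1])
            start, _ = pool.docs.pop(i)
            piece = pool.pool[start:start + remaining]
        row.append(piece)
    print(torch.cat(row))  # one packed row of exactly 8 tokens

The intent suggested by the commit title is that tokens now live in one long-lived tensor and only small (start, length) tuples pass through the Python heap, so the garbage collector no longer has to track a large buffer of per-document token lists between batches.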