diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py
index e95c3af..7086038 100644
--- a/nanochat/dataloader.py
+++ b/nanochat/dataloader.py
@@ -144,66 +144,92 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
     row_capacity = T + 1
     batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
     bos_token = tokenizer.get_bos_token_id()
-    doc_buffer = []
     pq_idx, rg_idx, epoch = 0, 0, 1
 
+    # Token pool: single tensor holding all buffered tokens
+    # Documents tracked as (start, length) tuples
+    pool = torch.empty(buffer_size * 512, dtype=torch.long)
+    pool_end = 0
+    docs = []  # [(start, length), ...]
+
+    def compact_pool():
+        """Shift active documents to front of pool, reclaiming space."""
+        nonlocal pool_end
+        if not docs:
+            pool_end = 0
+            return
+        write_pos = 0
+        for i, (start, length) in enumerate(docs):
+            if start != write_pos:
+                pool[write_pos:write_pos + length] = pool[start:start + length].clone()
+                docs[i] = (write_pos, length)
+            write_pos += length
+        pool_end = write_pos
+
     def refill_buffer():
-        nonlocal pq_idx, rg_idx, epoch
+        """Retrieve more docs and add them to the pool"""
+        nonlocal pq_idx, rg_idx, epoch, pool, pool_end
         doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
         token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
+        # Number of new tokens to store
+        total_new = sum(len(t) for t in token_lists)
+        # If there's not enough space at the end,
+        if pool_end + total_new > pool.size(0):
+            compact_pool()  # Try compacting first.
+        # If still not enough,
+        if pool_end + total_new > pool.size(0):
+            # Allocate a new, larger pool.
+            new_size = max(pool.size(0) * 2, pool_end + total_new)
+            new_pool = torch.empty(new_size, dtype=torch.long)
+            new_pool[:pool_end] = pool[:pool_end]
+            pool = new_pool
+        # Write tokens to pool
         for tokens in token_lists:
-            doc_buffer.append(tokens)
+            n = len(tokens)
+            pool[pool_end:pool_end + n] = torch.tensor(tokens, dtype=torch.long)
+            docs.append((pool_end, n))
+            pool_end += n
 
-    # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
-    # This gives us contiguous views and a single HtoD transfer
+    # Pre-allocate buffers once
     use_cuda = device == "cuda"
-    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
-    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
-    cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
-    cpu_targets = cpu_buffer[B * T:].view(B, T)
-    inputs = gpu_buffer[:B * T].view(B, T)
-    targets = gpu_buffer[B * T:].view(B, T)
+    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
+    inputs = torch.empty((B, T), dtype=torch.long, device=device)
+    targets = torch.empty((B, T), dtype=torch.long, device=device)
 
     while True:
-        rows = []
-        for _ in range(B):
-            row = []
-            while len(row) < row_capacity:
+        for row_idx in range(B):
+            col = 0
+            while col < row_capacity:
                 # Ensure buffer has documents
-                while len(doc_buffer) < buffer_size:
+                while len(docs) < buffer_size:
                     refill_buffer()
-                remaining = row_capacity - len(row)
+                remaining = row_capacity - col
                 # Find largest doc that fits entirely
                 best_idx = -1
                 best_len = 0
-                for i, doc in enumerate(doc_buffer):
-                    doc_len = len(doc)
-                    if doc_len <= remaining and doc_len > best_len:
+                for i, (start, length) in enumerate(docs):
+                    if length <= remaining and length > best_len:
                         best_idx = i
-                        best_len = doc_len
+                        best_len = length
                 if best_idx >= 0:
-                    doc = doc_buffer.pop(best_idx)
-                    row.extend(doc)
+                    start, length = docs.pop(best_idx)
+                    row_buffer[row_idx, col:col + length] = pool[start:start + length]
+                    col += length
                 else:
-                    # No doc fits - crop shortest in buffer to fill remaining and minimize waste
-                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
-                    doc = doc_buffer.pop(shortest_idx)
-                    row.extend(doc[:remaining])
+                    # No doc fits - crop shortest to fill remaining
+                    shortest_idx = min(range(len(docs)), key=lambda i: docs[i][1])
+                    start, length = docs.pop(shortest_idx)
+                    row_buffer[row_idx, col:col + remaining] = pool[start:start + remaining]
+                    col += remaining
 
-            rows.append(row[:row_capacity])
-
-        # Convert rows to tensor and copy slices to pinned buffer (CPU work)
-        row_data = torch.tensor(rows, dtype=torch.long) # [B, T+1], temporary
-        cpu_inputs.copy_(row_data[:, :-1])
-        cpu_targets.copy_(row_data[:, 1:])
+        # Copy to GPU
+        inputs.copy_(row_buffer[:, :-1], non_blocking=use_cuda)
+        targets.copy_(row_buffer[:, 1:], non_blocking=use_cuda)
         state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
-
-        # Single HtoD copy into persistent GPU buffer and yield
-        gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
         yield inputs, targets, state_dict
 
 def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
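For reviewers who want to poke at the new buffering scheme in isolation, here is a minimal standalone sketch of what the `+` lines implement: one flat token tensor, `(start, length)` spans instead of per-document Python lists, left-compaction to reclaim holes left by popped spans, and best-fit row packing with a crop fallback. The `TokenPool` class and `take_best_fit` helper are illustrative names invented for this sketch; they do not exist in nanochat, and the real loader interleaves this logic with tokenization and refilling as shown in the diff.

```python
import torch

class TokenPool:
    """Flat token tensor + (start, length) spans, mirroring the diff's scheme."""

    def __init__(self, capacity=16):
        self.pool = torch.empty(capacity, dtype=torch.long)
        self.pool_end = 0
        self.docs = []  # [(start, length), ...]

    def compact(self):
        # Shift live spans to the front; popped spans leave holes we reclaim here.
        write_pos = 0
        for i, (start, length) in enumerate(self.docs):
            if start != write_pos:
                self.pool[write_pos:write_pos + length] = self.pool[start:start + length].clone()
                self.docs[i] = (write_pos, length)
            write_pos += length
        self.pool_end = write_pos

    def add(self, tokens):
        n = len(tokens)
        if self.pool_end + n > self.pool.size(0):
            self.compact()  # try reclaiming holes first
        if self.pool_end + n > self.pool.size(0):
            new_size = max(self.pool.size(0) * 2, self.pool_end + n)  # then grow
            new_pool = torch.empty(new_size, dtype=torch.long)
            new_pool[:self.pool_end] = self.pool[:self.pool_end]
            self.pool = new_pool
        self.pool[self.pool_end:self.pool_end + n] = torch.tensor(tokens, dtype=torch.long)
        self.docs.append((self.pool_end, n))
        self.pool_end += n

    def take_best_fit(self, remaining):
        # Pop the largest span that fits entirely; else crop the shortest span.
        # Returns a view into the pool: copy it out before the next compact().
        best_idx, best_len = -1, 0
        for i, (_, length) in enumerate(self.docs):
            if best_len < length <= remaining:
                best_idx, best_len = i, length
        if best_idx >= 0:
            start, length = self.docs.pop(best_idx)
            return self.pool[start:start + length]
        shortest_idx = min(range(len(self.docs)), key=lambda i: self.docs[i][1])
        start, _ = self.docs.pop(shortest_idx)
        return self.pool[start:start + remaining]

# Pack a single row of 8 tokens, the way the diff's inner while-loop does.
pool = TokenPool()
for doc in ([1, 2, 3], [4, 5], [6, 7, 8, 9, 10, 11], [12]):
    pool.add(doc)
row, capacity = [], 8
while len(row) < capacity:
    row.extend(pool.take_best_fit(capacity - len(row)).tolist())
print(row)  # [6, 7, 8, 9, 10, 11, 4, 5] -- best-fit picks the 6-doc, then the 2-doc
```

Two caveats the sketch shares with the diff: the crop fallback assumes the buffer is non-empty (in the loader, `refill_buffer()` guarantees this before each fit search), and a popped span must be consumed (here via `.tolist()`, in the diff via the `row_buffer` copy) before the next compaction can overwrite it.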