mirror of
https://github.com/karpathy/nanochat.git
synced 2026-02-04 17:49:49 +00:00
Merge 941ce3cce8 into 1ddaad1c1c
This commit is contained in:
commit
445b431d7c
|
|
@ -1,24 +1,19 @@
|
|||
"""
|
||||
Distributed dataloaders for pretraining.
|
||||
|
||||
Two implementations are provided:
|
||||
|
||||
1. Original (tokenizing_distributed_data_loader):
|
||||
- Streams tokens into a flat buffer, reshapes to (B, T)
|
||||
- Rows may start mid-document (no guaranteed BOS at position 0)
|
||||
- 100% token utilization, simple and efficient
|
||||
|
||||
2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
|
||||
BOS-aligned bestfit:
|
||||
- Every row starts with BOS token
|
||||
- Documents packed using best-fit algorithm to minimize cropping
|
||||
- When no document fits remaining space, crops a document to fill exactly
|
||||
- 100% utilization (no padding), ~35% tokens cropped at T=2048
|
||||
|
||||
The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
|
||||
Compared to the original tokenizing_distributed_data_loader:
|
||||
BOS-aligned loses ~35% of tokens to cropping, but ensures that
|
||||
there are fewer "confusing" tokens in the train/val batches as every token can
|
||||
now attend back to the BOS token and sees the full context of the document.
|
||||
(2) is the new default if you have enough data.
|
||||
Fallback to (1) if you have very limited data AND long documents.
|
||||
|
||||
Fallback to the original if you have very limited data AND long documents:
|
||||
https://github.com/karpathy/nanochat/blob/3c3a3d7/nanochat/dataloader.py#L78-L117
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
|
@ -75,48 +70,6 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
|
|||
epoch += 1
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
|
||||
"""
|
||||
Stream pretraining text from parquet files, tokenize, yield training batches.
|
||||
|
||||
This is the original dataloader that streams tokens into a flat buffer and reshapes.
|
||||
Rows may start mid-document (no guaranteed BOS at position 0).
|
||||
|
||||
Supports approximate resume via state_dict.
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
||||
needed_tokens = B * T + 1 # +1 for target at last position
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
token_buffer = []
|
||||
pq_idx, rg_idx, epoch = 0, 0, 1
|
||||
|
||||
while True:
|
||||
|
||||
# Accumulate enough tokens
|
||||
while len(token_buffer) < needed_tokens:
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
token_buffer.extend(tokens)
|
||||
tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token)
|
||||
token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over
|
||||
|
||||
# Package tokens into inputs and targets, yield
|
||||
use_cuda = device == "cuda"
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
|
||||
targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
|
||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader(*args, **kwargs):
|
||||
"""Helper that omits state_dict from yields."""
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
|
||||
yield inputs, targets
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
||||
tokenizer, B, T, split,
|
||||
tokenizer_threads=4, tokenizer_batch_size=128,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user