fix dataloader for midtrain to never crop data. we can't just throw it away like we do in pretraining

This commit is contained in:
Andrej Karpathy 2026-01-31 18:21:36 +00:00
parent 3c3a3d7042
commit 348fbb301b

View File

@ -125,11 +125,12 @@ approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
current_epoch = 1 # track epoch for logging
def mid_data_generator_bos_bestfit(split, buffer_size=100):
"""
BOS-aligned dataloader for midtraining with bestfit-crop packing.
BOS-aligned dataloader for midtraining with bestfit-pad packing.
Each row in the batch starts with BOS (beginning of a conversation).
Conversations are packed using best-fit algorithm to minimize cropping.
This matches the BOS-aligned approach used in pretraining.
Conversations are packed using best-fit algorithm. When no conversation fits,
the row is padded (instead of cropping) to ensure no tokens are ever discarded.
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
"""
global last_step, approx_progress, current_epoch
assert split in {"train", "val"}, "split must be 'train' or 'val'"
@ -137,6 +138,7 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
dataset_size = len(dataset)
assert dataset_size > 0
row_capacity = args.max_seq_len + 1 # +1 for target at last position
bos_token = tokenizer.get_bos_token_id()
# Conversation buffer: list of token lists
conv_buffer = []
@ -159,8 +161,10 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
while True:
rows = []
row_lengths = [] # Track actual content length (excluding padding) for each row
for _ in range(args.device_batch_size):
row = []
padded = False
while len(row) < row_capacity:
# Ensure buffer has conversations
while len(conv_buffer) < buffer_size:
@ -183,11 +187,18 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
row.extend(conv)
consumed += ddp_world_size # Track actual consumption
else:
# No conversation fits - crop first conversation to fill remaining
conv = conv_buffer.pop(0)
row.extend(conv[:remaining])
consumed += ddp_world_size # Track actual consumption
# No conversation fits - pad the remainder instead of cropping
# This ensures we never discard any tokens
content_len = len(row)
row.extend([bos_token] * remaining) # Pad with BOS tokens
padded = True
break # Row is now full (with padding)
# Track content length: full row if no padding, otherwise the length before padding
if padded:
row_lengths.append(content_len)
else:
row_lengths.append(row_capacity)
rows.append(row[:row_capacity])
# Stopping condition to respect num_iterations, if given
@ -212,6 +223,12 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
# Mask out padding positions in targets (set to -1 = ignore_index)
# For each row, positions >= (content_length - 1) in targets should be masked
for i, content_len in enumerate(row_lengths):
if content_len < row_capacity:
targets[i, content_len-1:] = -1
yield inputs, targets
train_loader = mid_data_generator_bos_bestfit("train")