From 348fbb301b8b709ad5d59bdf69e99a51982f594a Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 31 Jan 2026 18:21:36 +0000 Subject: [PATCH] fix dataloader for midtrain to never crop data. we can't just throw it away like we do in pretraining --- scripts/mid_train.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/scripts/mid_train.py b/scripts/mid_train.py index ebe9cd5..54c5fb0 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -125,11 +125,12 @@ approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch current_epoch = 1 # track epoch for logging def mid_data_generator_bos_bestfit(split, buffer_size=100): """ - BOS-aligned dataloader for midtraining with bestfit-crop packing. + BOS-aligned dataloader for midtraining with bestfit-pad packing. Each row in the batch starts with BOS (beginning of a conversation). - Conversations are packed using best-fit algorithm to minimize cropping. - This matches the BOS-aligned approach used in pretraining. + Conversations are packed using best-fit algorithm. When no conversation fits, + the row is padded (instead of cropping) to ensure no tokens are ever discarded. + Padding positions have targets masked with -1 (ignore_index for cross-entropy). """ global last_step, approx_progress, current_epoch assert split in {"train", "val"}, "split must be 'train' or 'val'" @@ -137,6 +138,7 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100): dataset_size = len(dataset) assert dataset_size > 0 row_capacity = args.max_seq_len + 1 # +1 for target at last position + bos_token = tokenizer.get_bos_token_id() # Conversation buffer: list of token lists conv_buffer = [] @@ -159,8 +161,10 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100): while True: rows = [] + row_lengths = [] # Track actual content length (excluding padding) for each row for _ in range(args.device_batch_size): row = [] + padded = False while len(row) < row_capacity: # Ensure buffer has conversations while len(conv_buffer) < buffer_size: @@ -183,11 +187,18 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100): row.extend(conv) consumed += ddp_world_size # Track actual consumption else: - # No conversation fits - crop first conversation to fill remaining - conv = conv_buffer.pop(0) - row.extend(conv[:remaining]) - consumed += ddp_world_size # Track actual consumption + # No conversation fits - pad the remainder instead of cropping + # This ensures we never discard any tokens + content_len = len(row) + row.extend([bos_token] * remaining) # Pad with BOS tokens + padded = True + break # Row is now full (with padding) + # Track content length: full row if no padding, otherwise the length before padding + if padded: + row_lengths.append(content_len) + else: + row_lengths.append(row_capacity) rows.append(row[:row_capacity]) # Stopping condition to respect num_iterations, if given @@ -212,6 +223,12 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100): inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda) targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda) + # Mask out padding positions in targets (set to -1 = ignore_index) + # For each row, positions >= (content_length - 1) in targets should be masked + for i, content_len in enumerate(row_lengths): + if content_len < row_capacity: + targets[i, content_len-1:] = -1 + yield inputs, targets train_loader = mid_data_generator_bos_bestfit("train")