Mirror of https://github.com/karpathy/nanochat.git (synced 2026-02-03 09:09:49 +00:00)
fix dataloader for midtrain to never crop data. we can't just throw it away like we do in pretraining
This commit is contained in:
parent
3c3a3d7042
commit
348fbb301b
@@ -125,11 +125,12 @@ approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
 current_epoch = 1 # track epoch for logging
 def mid_data_generator_bos_bestfit(split, buffer_size=100):
     """
-    BOS-aligned dataloader for midtraining with bestfit-crop packing.
+    BOS-aligned dataloader for midtraining with bestfit-pad packing.
 
     Each row in the batch starts with BOS (beginning of a conversation).
-    Conversations are packed using best-fit algorithm to minimize cropping.
-    This matches the BOS-aligned approach used in pretraining.
+    Conversations are packed using best-fit algorithm. When no conversation fits,
+    the row is padded (instead of cropping) to ensure no tokens are ever discarded.
+    Padding positions have targets masked with -1 (ignore_index for cross-entropy).
     """
     global last_step, approx_progress, current_epoch
     assert split in {"train", "val"}, "split must be 'train' or 'val'"
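Note: the docstring's last line relies on the loss function skipping masked targets. A minimal sketch of that mechanism (toy values, not nanochat's training loop): PyTorch's F.cross_entropy only ignores labels equal to its ignore_index argument, which defaults to -100, so the loss here is presumably computed with ignore_index=-1.

    import torch
    import torch.nn.functional as F

    vocab_size = 8
    logits = torch.randn(2, 4, vocab_size)   # (batch, seq, vocab)
    targets = torch.tensor([[3, 5, -1, -1],  # row 0: last two positions are padding
                            [1, 2, 4, 6]])   # row 1: fully packed, nothing masked
    # Only the six non-(-1) positions contribute to the mean loss.
    loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1), ignore_index=-1)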
@@ -137,6 +138,7 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
     dataset_size = len(dataset)
     assert dataset_size > 0
     row_capacity = args.max_seq_len + 1 # +1 for target at last position
+    bos_token = tokenizer.get_bos_token_id()
 
     # Conversation buffer: list of token lists
     conv_buffer = []
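The +1 in row_capacity is the usual next-token-prediction shift: a row of max_seq_len + 1 tokens yields exactly max_seq_len (input, target) pairs. A toy illustration (not nanochat code):

    max_seq_len = 4
    row = [10, 11, 12, 13, 14]           # row_capacity = max_seq_len + 1 = 5 tokens
    inputs, targets = row[:-1], row[1:]  # targets are inputs shifted left by one
    assert len(inputs) == len(targets) == max_seq_len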
@@ -159,8 +161,10 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
 
     while True:
         rows = []
+        row_lengths = [] # Track actual content length (excluding padding) for each row
         for _ in range(args.device_batch_size):
             row = []
+            padded = False
             while len(row) < row_capacity:
                 # Ensure buffer has conversations
                 while len(conv_buffer) < buffer_size:
@@ -183,11 +187,18 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
                     row.extend(conv)
                     consumed += ddp_world_size # Track actual consumption
                 else:
-                    # No conversation fits - crop first conversation to fill remaining
-                    conv = conv_buffer.pop(0)
-                    row.extend(conv[:remaining])
-                    consumed += ddp_world_size # Track actual consumption
+                    # No conversation fits - pad the remainder instead of cropping
+                    # This ensures we never discard any tokens
+                    content_len = len(row)
+                    row.extend([bos_token] * remaining) # Pad with BOS tokens
+                    padded = True
+                    break # Row is now full (with padding)
 
+            # Track content length: full row if no padding, otherwise the length before padding
+            if padded:
+                row_lengths.append(content_len)
+            else:
+                row_lengths.append(row_capacity)
             rows.append(row[:row_capacity])
 
             # Stopping condition to respect num_iterations, if given
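Pulled out of the surrounding loop, the new pack-or-pad step looks roughly like the sketch below. This is a hypothetical standalone helper, not the actual code: buffer refilling and DDP consumption bookkeeping are omitted, and it assumes "best fit" means choosing the longest buffered conversation that still fits in the remaining space.

    def fill_row(conv_buffer, row_capacity, bos_token):
        """Sketch: fill one row by best fit; pad with BOS when nothing fits."""
        row = []
        while len(row) < row_capacity:
            remaining = row_capacity - len(row)
            fits = [i for i, c in enumerate(conv_buffer) if len(c) <= remaining]
            if fits:
                best = max(fits, key=lambda i: len(conv_buffer[i]))  # longest that fits
                row.extend(conv_buffer.pop(best))
            else:
                content_len = len(row)
                row.extend([bos_token] * remaining)  # pad; no tokens are discarded
                return row, content_len              # padded row
        return row, row_capacity                     # fully packed row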
@@ -212,6 +223,12 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
         inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
         targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
+
+        # Mask out padding positions in targets (set to -1 = ignore_index)
+        # For each row, positions >= (content_length - 1) in targets should be masked
+        for i, content_len in enumerate(row_lengths):
+            if content_len < row_capacity:
+                targets[i, content_len-1:] = -1
 
         yield inputs, targets
 
 train_loader = mid_data_generator_bos_bestfit("train")
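The content_len-1 offset follows from the shift between inputs and targets: target index j corresponds to row position j+1, so the first padded row position, content_len, lands in targets at index content_len - 1. A toy check (toy values, not nanochat code):

    import torch

    content_len, bos = 4, 0
    row = torch.tensor([bos, 11, 12, 13, bos, bos])  # 4 real tokens + 2 BOS pads
    inputs, targets = row[:-1], row[1:].clone()
    targets[content_len - 1:] = -1                   # mask exactly the padded targets
    print(targets.tolist())                          # [11, 12, 13, -1, -1]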