mirror of https://github.com/karpathy/nanochat.git
synced 2026-02-03 17:19:50 +00:00

fix dataloader for midtrain to never crop data. we can't just throw it away like we do in pretraining

This commit is contained in:
parent 3c3a3d7042
commit 348fbb301b
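The gist of the fix: when no buffered conversation fits in the space left at the end of a row, the old loader cropped a conversation to fill the gap, discarding its tail; the new loader pads the gap instead and masks the padding out of the loss. A toy contrast of the two behaviors, with made-up token lists rather than repo code:

    convs = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6]]  # two conversations, 1 = BOS
    capacity = 8
    remaining = capacity - len(convs[0])           # 3 slots left after convs[0]
    # old behavior: crop convs[1] into the gap, discarding its last 3 tokens
    cropped_row = convs[0] + convs[1][:remaining]
    # new behavior: pad the gap with BOS; convs[1] stays buffered for a later row
    padded_row = convs[0] + [1] * remaining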
@@ -125,11 +125,12 @@ approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
 current_epoch = 1 # track epoch for logging
 def mid_data_generator_bos_bestfit(split, buffer_size=100):
     """
-    BOS-aligned dataloader for midtraining with bestfit-crop packing.
+    BOS-aligned dataloader for midtraining with bestfit-pad packing.

     Each row in the batch starts with BOS (beginning of a conversation).
-    Conversations are packed using a best-fit algorithm to minimize cropping.
-    This matches the BOS-aligned approach used in pretraining.
+    Conversations are packed using a best-fit algorithm. When no conversation fits,
+    the row is padded (instead of cropped) to ensure no tokens are ever discarded.
+    Padding positions have targets masked with -1 (ignore_index for cross-entropy).
     """
     global last_step, approx_progress, current_epoch
     assert split in {"train", "val"}, "split must be 'train' or 'val'"
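The best-fit selection itself is unchanged by this commit. For context, a minimal sketch of the usual best-fit rule (a hypothetical helper, not the repo's exact code): pick the longest buffered conversation that still fits in the remaining slots.

    def best_fit_index(conv_buffer, remaining):
        # Return the index of the longest conversation that fits into
        # `remaining` token slots, or None if nothing fits.
        best_idx, best_len = None, -1
        for i, conv in enumerate(conv_buffer):
            if best_len < len(conv) <= remaining:
                best_idx, best_len = i, len(conv)
        return best_idx

The else: branch changed further down in this diff handles exactly the case where such a lookup comes up empty.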
@@ -137,6 +138,7 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
     dataset_size = len(dataset)
     assert dataset_size > 0
     row_capacity = args.max_seq_len + 1 # +1 for target at last position
+    bos_token = tokenizer.get_bos_token_id()

     # Conversation buffer: list of token lists
     conv_buffer = []
@@ -159,8 +161,10 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):

     while True:
         rows = []
+        row_lengths = [] # Track actual content length (excluding padding) for each row
         for _ in range(args.device_batch_size):
             row = []
+            padded = False
             while len(row) < row_capacity:
                 # Ensure buffer has conversations
                 while len(conv_buffer) < buffer_size:
@@ -183,11 +187,18 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
                     row.extend(conv)
                     consumed += ddp_world_size # Track actual consumption
                 else:
-                    # No conversation fits - crop first conversation to fill remaining
-                    conv = conv_buffer.pop(0)
-                    row.extend(conv[:remaining])
-                    consumed += ddp_world_size # Track actual consumption
+                    # No conversation fits - pad the remainder instead of cropping
+                    # This ensures we never discard any tokens
+                    content_len = len(row)
+                    row.extend([bos_token] * remaining) # Pad with BOS tokens
+                    padded = True
+                    break # Row is now full (with padding)

+            # Track content length: full row if no padding, otherwise the length before padding
+            if padded:
+                row_lengths.append(content_len)
+            else:
+                row_lengths.append(row_capacity)
             rows.append(row[:row_capacity])

         # Stopping condition to respect num_iterations, if given
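Padding with BOS rather than a dedicated pad token is safe here because every padded target position is masked in the next hunk, so the pad token's identity never contributes to the loss; it only needs to be a valid input token, and BOS arguably keeps the padded inputs in-distribution. A small illustration of the invariant the loop above maintains (toy numbers, not repo code):

    row_capacity = 8
    bos = 0
    row = [bos, 5, 6, 7, bos, 8]                      # packed conversations, 6 real tokens
    content_len = len(row)                            # recorded before padding
    row.extend([bos] * (row_capacity - content_len))  # pad to capacity with BOS
    assert len(row) == row_capacity                   # every row is exactly row_capacity long
    # content_len goes into row_lengths and later drives the target mask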
@@ -212,6 +223,12 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
         inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
         targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)

+        # Mask out padding positions in targets (set to -1 = ignore_index)
+        # For each row, positions >= (content_length - 1) in targets should be masked
+        for i, content_len in enumerate(row_lengths):
+            if content_len < row_capacity:
+                targets[i, content_len-1:] = -1
+
         yield inputs, targets

 train_loader = mid_data_generator_bos_bestfit("train")
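The content_len - 1 offset in the mask comes from the input/target shift: target index t predicts row token t + 1, so the first padded token (at row index content_len) appears as a target at index content_len - 1. A self-contained check of the arithmetic, with assumed toy shapes rather than repo code:

    import torch
    import torch.nn.functional as F

    row = torch.tensor([10, 11, 12, 0, 0])  # row_capacity = 5; 0 = BOS padding
    content_len = 3                          # only row[0:3] is real content
    inputs, targets = row[:-1], row[1:].clone()
    targets[content_len - 1:] = -1           # row[3], the first pad, is the target at index 2
    logits = torch.randn(4, 50)              # stand-in model output over a 50-token vocab
    loss = F.cross_entropy(logits, targets, ignore_index=-1)
    # only target indices 0 and 1 (real next tokens 11 and 12) contribute to the loss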