avoid fully padded rows by truncating shortest convo

This commit is contained in:
Dylan Chen 2026-02-04 12:16:01 +08:00
parent d4db003661
commit 7c0eb3f00b

View File

@ -130,7 +130,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
Each row in the batch starts with BOS (beginning of a conversation).
Conversations are packed using best-fit algorithm. When no conversation fits,
the row is padded (instead of cropping) to ensure no tokens are ever discarded.
the row is padded. If the row is empty (no buffered conversation fits even
in an empty row), the shortest conversation is truncated to the row capacity
instead, to avoid fully padded rows.
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
"""
global last_step, approx_progress, current_epoch
@ -152,7 +153,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
nonlocal cursor, epoch
while len(conv_buffer) < buffer_size:
conversation = dataset[cursor]
ids, _ = tokenizer.render_conversation(conversation, max_tokens=row_capacity)
ids, _ = tokenizer.render_conversation(conversation)
conv_buffer.append(ids)
cursor += ddp_world_size
if cursor >= dataset_size:
@ -188,12 +189,25 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
row.extend(conv)
consumed += ddp_world_size # Track actual consumption
else:
# No conversation fits - pad the remainder instead of cropping
# This ensures we never discard any tokens
content_len = len(row)
row.extend([bos_token] * remaining) # Pad with BOS tokens
padded = True
break # Row is now full (with padding)
if len(row) == 0:
# No conversation fits even on an empty row.
# Truncate the shortest one to avoid fully padded rows.
trunc_idx = -1
trunc_len = None
for i, conv in enumerate(conv_buffer):
conv_len = len(conv)
if conv_len > remaining and (trunc_len is None or conv_len < trunc_len):
trunc_idx = i
trunc_len = conv_len
conv = conv_buffer.pop(trunc_idx)
row.extend(conv[:row_capacity])
consumed += ddp_world_size
else:
# No conversation fits - pad the remainder
content_len = len(row)
row.extend([bos_token] * remaining) # Pad with BOS tokens
padded = True
break # Row is now full (with padding or truncation)
# Track content length: full row if no padding, otherwise the length before padding
if padded: