This commit is contained in:
Dylan 2026-03-28 02:03:28 +00:00 committed by GitHub
commit 115a6eb8dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -190,7 +190,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
Each row in the batch starts with BOS (beginning of a conversation).
Conversations are packed using best-fit algorithm. When no conversation fits,
the row is padded (instead of cropping) so that no tokens are discarded.
If the row is still empty because no conversation fits even on an empty row,
we fall back to truncating one conversation to avoid a fully padded row.
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
"""
global last_step, approx_progress, current_epoch
@ -251,13 +252,27 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
mask_row.extend(conv_mask)
consumed += ddp_world_size # Track actual consumption
else:
# No conversation fits - pad the remainder instead of cropping
# This ensures we never discard any tokens
content_len = len(row)
row.extend([bos_token] * remaining) # Pad with BOS tokens
mask_row.extend([0] * remaining)
padded = True
break # Row is now full (with padding)
if len(row) == 0:
# No conversation fits even on an empty row.
# Truncate the shortest oversized conversation to avoid fully padded rows.
trunc_idx = -1
trunc_len = None
for i, (conv, _) in enumerate(conv_buffer):
conv_len = len(conv)
if conv_len > remaining and (trunc_len is None or conv_len < trunc_len):
trunc_idx = i
trunc_len = conv_len
conv, conv_mask = conv_buffer.pop(trunc_idx)
row.extend(conv[:row_capacity])
mask_row.extend(conv_mask[:row_capacity])
consumed += ddp_world_size
else:
# No conversation fits - pad the remainder instead of dropping tokens.
content_len = len(row)
row.extend([bos_token] * remaining)
mask_row.extend([0] * remaining)
padded = True
break # Row is now full (with padding or truncation)
# Track content length: full row if no padding, otherwise the length before padding
if padded: