This commit is contained in:
Dylan 2026-03-15 02:57:40 +08:00 committed by GitHub
commit 087d48bd61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -190,7 +190,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
Each row in the batch starts with BOS (beginning of a conversation). Each row in the batch starts with BOS (beginning of a conversation).
Conversations are packed using best-fit algorithm. When no conversation fits, Conversations are packed using best-fit algorithm. When no conversation fits,
the row is padded (instead of cropping) to ensure no tokens are ever discarded. the row is padded, unless the row is still empty: in that case we fall back to truncating
the shortest oversized conversation to the row capacity, to avoid fully padded rows.
Padding positions have targets masked with -1 (ignore_index for cross-entropy). Padding positions have targets masked with -1 (ignore_index for cross-entropy).
""" """
global last_step, approx_progress, current_epoch global last_step, approx_progress, current_epoch
@ -251,13 +252,27 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
mask_row.extend(conv_mask) mask_row.extend(conv_mask)
consumed += ddp_world_size # Track actual consumption consumed += ddp_world_size # Track actual consumption
else: else:
# No conversation fits - pad the remainder instead of cropping if len(row) == 0:
# This ensures we never discard any tokens # No conversation fits even on an empty row.
content_len = len(row) # Truncate the shortest oversized conversation to avoid fully padded rows.
row.extend([bos_token] * remaining) # Pad with BOS tokens trunc_idx = -1
mask_row.extend([0] * remaining) trunc_len = None
padded = True for i, (conv, _) in enumerate(conv_buffer):
break # Row is now full (with padding) conv_len = len(conv)
if conv_len > remaining and (trunc_len is None or conv_len < trunc_len):
trunc_idx = i
trunc_len = conv_len
conv, conv_mask = conv_buffer.pop(trunc_idx)
row.extend(conv[:row_capacity])
mask_row.extend(conv_mask[:row_capacity])
consumed += ddp_world_size
else:
# No conversation fits - pad the remainder instead of dropping tokens.
content_len = len(row)
row.extend([bos_token] * remaining)
mask_row.extend([0] * remaining)
padded = True
break # Row is now full (with padding or truncation)
# Track content length: full row if no padding, otherwise the length before padding # Track content length: full row if no padding, otherwise the length before padding
if padded: if padded: