diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index b46dd81..db7a079 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -190,7 +190,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): Each row in the batch starts with BOS (beginning of a conversation). Conversations are packed using best-fit algorithm. When no conversation fits, - the row is padded (instead of cropping) to ensure no tokens are ever discarded. + the row is padded. If the row is still empty, we fall back to truncating one + conversation to avoid fully padded rows. Padding positions have targets masked with -1 (ignore_index for cross-entropy). """ global last_step, approx_progress, current_epoch @@ -251,13 +252,27 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): mask_row.extend(conv_mask) consumed += ddp_world_size # Track actual consumption else: - # No conversation fits - pad the remainder instead of cropping - # This ensures we never discard any tokens - content_len = len(row) - row.extend([bos_token] * remaining) # Pad with BOS tokens - mask_row.extend([0] * remaining) - padded = True - break # Row is now full (with padding) + if len(row) == 0: + # No conversation fits even on an empty row. + # Truncate the shortest oversized conversation to avoid fully padded rows. + trunc_idx = -1 + trunc_len = None + for i, (conv, _) in enumerate(conv_buffer): + conv_len = len(conv) + if conv_len > remaining and (trunc_len is None or conv_len < trunc_len): + trunc_idx = i + trunc_len = conv_len + conv, conv_mask = conv_buffer.pop(trunc_idx) + row.extend(conv[:row_capacity]) + mask_row.extend(conv_mask[:row_capacity]) + consumed += ddp_world_size + else: + # No conversation fits - pad the remainder instead of dropping tokens. + content_len = len(row) + row.extend([bos_token] * remaining) + mask_row.extend([0] * remaining) + padded = True + break # Row is now full (with padding or truncation) # Track content length: full row if no padding, otherwise the length before padding if padded: