diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index fb0c061..7a18ad3 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -130,7 +130,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): Each row in the batch starts with BOS (beginning of a conversation). Conversations are packed using best-fit algorithm. When no conversation fits, - the row is padded (instead of cropping) to ensure no tokens are ever discarded. + the row is padded. If the row is still empty, we fall back to truncating one + conversation to avoid fully padded rows. Padding positions have targets masked with -1 (ignore_index for cross-entropy). """ global last_step, approx_progress, current_epoch @@ -152,7 +153,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): nonlocal cursor, epoch while len(conv_buffer) < buffer_size: conversation = dataset[cursor] - ids, _ = tokenizer.render_conversation(conversation, max_tokens=row_capacity) + ids, _ = tokenizer.render_conversation(conversation) conv_buffer.append(ids) cursor += ddp_world_size if cursor >= dataset_size: @@ -188,12 +189,25 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): row.extend(conv) consumed += ddp_world_size # Track actual consumption else: - # No conversation fits - pad the remainder instead of cropping - # This ensures we never discard any tokens - content_len = len(row) - row.extend([bos_token] * remaining) # Pad with BOS tokens - padded = True - break # Row is now full (with padding) + if len(row) == 0: + # No conversation fits even on an empty row. + # Truncate the shortest one to avoid fully padded rows. + trunc_idx = -1 + trunc_len = None + for i, conv in enumerate(conv_buffer): + conv_len = len(conv) + if conv_len > remaining and (trunc_len is None or conv_len < trunc_len): + trunc_idx = i + trunc_len = conv_len + conv = conv_buffer.pop(trunc_idx) + row.extend(conv[:row_capacity]) + consumed += ddp_world_size + else: + # No conversation fits - pad the remainder + content_len = len(row) + row.extend([bos_token] * remaining) # Pad with BOS tokens + padded = True + break # Row is now full (with padding or truncation) # Track content length: full row if no padding, otherwise the length before padding if padded: