diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index f91822d..f31a2d3 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -293,6 +293,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): targets[mask_targets == 0] = -1 # Mask out padding positions in targets (set to -1 = ignore_index) + # For each row, positions >= (content_length - 1) in targets should be masked for i, content_len in enumerate(row_lengths): if content_len < row_capacity: targets[i, content_len-1:] = -1