diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 4c81f06..82d8531 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -152,7 +152,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): nonlocal cursor, epoch while len(conv_buffer) < buffer_size: conversation = dataset[cursor] - ids, _ = tokenizer.render_conversation(conversation) + ids, _ = tokenizer.render_conversation(conversation, max_tokens=row_capacity) conv_buffer.append(ids) cursor += ddp_world_size if cursor >= dataset_size: