mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 05:35:19 +00:00
avoid fully padded rows by truncating shortest convo
This commit is contained in:
parent
d4db003661
commit
7c0eb3f00b
|
|
@@ -130,7 +130,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
|
||||
Each row in the batch starts with BOS (beginning of a conversation).
|
||||
Conversations are packed using best-fit algorithm. When no conversation fits,
|
||||
the row is padded (instead of cropping) to ensure no tokens are ever discarded.
|
||||
the row is padded. If the row is still empty, we fall back to truncating one
|
||||
conversation to avoid fully padded rows.
|
||||
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
|
||||
"""
|
||||
global last_step, approx_progress, current_epoch
|
||||
|
|
@@ -152,7 +153,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
nonlocal cursor, epoch
|
||||
while len(conv_buffer) < buffer_size:
|
||||
conversation = dataset[cursor]
|
||||
ids, _ = tokenizer.render_conversation(conversation, max_tokens=row_capacity)
|
||||
ids, _ = tokenizer.render_conversation(conversation)
|
||||
conv_buffer.append(ids)
|
||||
cursor += ddp_world_size
|
||||
if cursor >= dataset_size:
|
||||
|
|
@@ -188,12 +189,25 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
row.extend(conv)
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
else:
|
||||
# No conversation fits - pad the remainder instead of cropping
|
||||
# This ensures we never discard any tokens
|
||||
content_len = len(row)
|
||||
row.extend([bos_token] * remaining) # Pad with BOS tokens
|
||||
padded = True
|
||||
break # Row is now full (with padding)
|
||||
if len(row) == 0:
|
||||
# No conversation fits even on an empty row.
|
||||
# Truncate the shortest one to avoid fully padded rows.
|
||||
trunc_idx = -1
|
||||
trunc_len = None
|
||||
for i, conv in enumerate(conv_buffer):
|
||||
conv_len = len(conv)
|
||||
if conv_len > remaining and (trunc_len is None or conv_len < trunc_len):
|
||||
trunc_idx = i
|
||||
trunc_len = conv_len
|
||||
conv = conv_buffer.pop(trunc_idx)
|
||||
row.extend(conv[:row_capacity])
|
||||
consumed += ddp_world_size
|
||||
else:
|
||||
# No conversation fits - pad the remainder
|
||||
content_len = len(row)
|
||||
row.extend([bos_token] * remaining) # Pad with BOS tokens
|
||||
padded = True
|
||||
break # Row is now full (with padding or truncation)
|
||||
|
||||
# Track content length: full row if no padding, otherwise the length before padding
|
||||
if padded:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user