mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 05:35:19 +00:00
Merge 67d63de2e6 into a445144d39
This commit is contained in:
commit
115a6eb8dc
|
|
@ -190,7 +190,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
|
||||
Each row in the batch starts with BOS (beginning of a conversation).
|
||||
Conversations are packed using best-fit algorithm. When no conversation fits,
|
||||
the row is padded (instead of cropping) to ensure no tokens are ever discarded.
|
||||
the row is padded. If the row is still empty, we fall back to truncating one
|
||||
conversation to avoid fully padded rows.
|
||||
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
|
||||
"""
|
||||
global last_step, approx_progress, current_epoch
|
||||
|
|
@ -251,13 +252,27 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
mask_row.extend(conv_mask)
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
else:
|
||||
# No conversation fits - pad the remainder instead of cropping
|
||||
# This ensures we never discard any tokens
|
||||
content_len = len(row)
|
||||
row.extend([bos_token] * remaining) # Pad with BOS tokens
|
||||
mask_row.extend([0] * remaining)
|
||||
padded = True
|
||||
break # Row is now full (with padding)
|
||||
if len(row) == 0:
|
||||
# No conversation fits even on an empty row.
|
||||
# Truncate the shortest oversized conversation to avoid fully padded rows.
|
||||
trunc_idx = -1
|
||||
trunc_len = None
|
||||
for i, (conv, _) in enumerate(conv_buffer):
|
||||
conv_len = len(conv)
|
||||
if conv_len > remaining and (trunc_len is None or conv_len < trunc_len):
|
||||
trunc_idx = i
|
||||
trunc_len = conv_len
|
||||
conv, conv_mask = conv_buffer.pop(trunc_idx)
|
||||
row.extend(conv[:row_capacity])
|
||||
mask_row.extend(conv_mask[:row_capacity])
|
||||
consumed += ddp_world_size
|
||||
else:
|
||||
# No conversation fits - pad the remainder instead of dropping tokens.
|
||||
content_len = len(row)
|
||||
row.extend([bos_token] * remaining)
|
||||
mask_row.extend([0] * remaining)
|
||||
padded = True
|
||||
break # Row is now full (with padding or truncation)
|
||||
|
||||
# Track content length: full row if no padding, otherwise the length before padding
|
||||
if padded:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user