mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-16 11:09:09 +00:00
Merge 67d63de2e6 into 1b1cc3c599
This commit is contained in:
commit
087d48bd61
|
|
@@ -190,7 +190,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
||||||
|
|
||||||
Each row in the batch starts with BOS (beginning of a conversation).
|
Each row in the batch starts with BOS (beginning of a conversation).
|
||||||
Conversations are packed using best-fit algorithm. When no conversation fits,
|
Conversations are packed using best-fit algorithm. When no conversation fits,
|
||||||
the row is padded (instead of cropping) to ensure no tokens are ever discarded.
|
the row is padded. If the row is still empty, we fall back to truncating one
|
||||||
|
conversation to avoid fully padded rows.
|
||||||
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
|
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
|
||||||
"""
|
"""
|
||||||
global last_step, approx_progress, current_epoch
|
global last_step, approx_progress, current_epoch
|
||||||
|
|
@@ -251,13 +252,27 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
||||||
mask_row.extend(conv_mask)
|
mask_row.extend(conv_mask)
|
||||||
consumed += ddp_world_size # Track actual consumption
|
consumed += ddp_world_size # Track actual consumption
|
||||||
else:
|
else:
|
||||||
# No conversation fits - pad the remainder instead of cropping
|
if len(row) == 0:
|
||||||
# This ensures we never discard any tokens
|
# No conversation fits even on an empty row.
|
||||||
content_len = len(row)
|
# Truncate the shortest oversized conversation to avoid fully padded rows.
|
||||||
row.extend([bos_token] * remaining) # Pad with BOS tokens
|
trunc_idx = -1
|
||||||
mask_row.extend([0] * remaining)
|
trunc_len = None
|
||||||
padded = True
|
for i, (conv, _) in enumerate(conv_buffer):
|
||||||
break # Row is now full (with padding)
|
conv_len = len(conv)
|
||||||
|
if conv_len > remaining and (trunc_len is None or conv_len < trunc_len):
|
||||||
|
trunc_idx = i
|
||||||
|
trunc_len = conv_len
|
||||||
|
conv, conv_mask = conv_buffer.pop(trunc_idx)
|
||||||
|
row.extend(conv[:row_capacity])
|
||||||
|
mask_row.extend(conv_mask[:row_capacity])
|
||||||
|
consumed += ddp_world_size
|
||||||
|
else:
|
||||||
|
# No conversation fits - pad the remainder instead of dropping tokens.
|
||||||
|
content_len = len(row)
|
||||||
|
row.extend([bos_token] * remaining)
|
||||||
|
mask_row.extend([0] * remaining)
|
||||||
|
padded = True
|
||||||
|
break # Row is now full (with padding or truncation)
|
||||||
|
|
||||||
# Track content length: full row if no padding, otherwise the length before padding
|
# Track content length: full row if no padding, otherwise the length before padding
|
||||||
if padded:
|
if padded:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user