mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-07 09:50:28 +00:00
Merge branch 'master' into fix-scaling-zero-division
This commit is contained in:
commit
be723b7afb
|
|
@ -197,7 +197,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
row_capacity = args.max_seq_len + 1 # +1 for target at last position
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
|
||||
# Conversation buffer: list of token lists
|
||||
# Conversation buffer: list of (token_ids, loss_mask) tuples
|
||||
conv_buffer = []
|
||||
cursor = ddp_rank # Each rank processes different conversations (for fetching)
|
||||
consumed = ddp_rank # Track actual consumption separately from buffering
|
||||
|
|
@ -208,8 +208,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
nonlocal cursor, epoch
|
||||
while len(conv_buffer) < buffer_size:
|
||||
conversation = dataset[cursor]
|
||||
ids, _ = tokenizer.render_conversation(conversation)
|
||||
conv_buffer.append(ids)
|
||||
ids, mask = tokenizer.render_conversation(conversation)
|
||||
conv_buffer.append((ids, mask))
|
||||
cursor += ddp_world_size
|
||||
if cursor >= dataset_size:
|
||||
cursor = cursor % dataset_size
|
||||
|
|
@ -218,9 +218,11 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
|
||||
while True:
|
||||
rows = []
|
||||
mask_rows = []
|
||||
row_lengths = [] # Track actual content length (excluding padding) for each row
|
||||
for _ in range(args.device_batch_size):
|
||||
row = []
|
||||
mask_row = []
|
||||
padded = False
|
||||
while len(row) < row_capacity:
|
||||
# Ensure buffer has conversations
|
||||
|
|
@ -232,7 +234,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
# Find largest conversation that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, conv in enumerate(conv_buffer):
|
||||
for i, (conv, _) in enumerate(conv_buffer):
|
||||
conv_len = len(conv)
|
||||
if conv_len <= remaining and conv_len > best_len:
|
||||
best_idx = i
|
||||
|
|
@ -240,14 +242,16 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
|
||||
if best_idx >= 0:
|
||||
# Found a conversation that fits - use it entirely
|
||||
conv = conv_buffer.pop(best_idx)
|
||||
conv, conv_mask = conv_buffer.pop(best_idx)
|
||||
row.extend(conv)
|
||||
mask_row.extend(conv_mask)
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
else:
|
||||
# No conversation fits - pad the remainder instead of cropping
|
||||
# This ensures we never discard any tokens
|
||||
content_len = len(row)
|
||||
row.extend([bos_token] * remaining) # Pad with BOS tokens
|
||||
mask_row.extend([0] * remaining)
|
||||
padded = True
|
||||
break # Row is now full (with padding)
|
||||
|
||||
|
|
@ -257,6 +261,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
else:
|
||||
row_lengths.append(row_capacity)
|
||||
rows.append(row[:row_capacity])
|
||||
mask_rows.append(mask_row[:row_capacity])
|
||||
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
|
|
@ -280,6 +285,13 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
|
||||
|
||||
# Apply the loss mask from render_conversation (mask=1 for assistant completions,
|
||||
# mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
|
||||
# with targets (shifted by 1). Unmasked positions get -1 (ignore_index).
|
||||
mask_tensor = torch.tensor(mask_rows, dtype=torch.int8)
|
||||
mask_targets = mask_tensor[:, 1:].to(device=device)
|
||||
targets[mask_targets == 0] = -1
|
||||
|
||||
# Mask out padding positions in targets (set to -1 = ignore_index)
|
||||
# For each row, positions >= (content_length - 1) in targets should be masked
|
||||
for i, content_len in enumerate(row_lengths):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user