Merge branch 'master' into fix-scaling-zero-division

This commit is contained in:
suraj-self 2026-03-03 11:22:31 +05:30
commit be723b7afb

View File

@@ -197,7 +197,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
row_capacity = args.max_seq_len + 1 # +1 for target at last position
bos_token = tokenizer.get_bos_token_id()
# Conversation buffer: list of token lists
# Conversation buffer: list of (token_ids, loss_mask) tuples
conv_buffer = []
cursor = ddp_rank # Each rank processes different conversations (for fetching)
consumed = ddp_rank # Track actual consumption separately from buffering
@@ -208,8 +208,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
nonlocal cursor, epoch
while len(conv_buffer) < buffer_size:
conversation = dataset[cursor]
ids, _ = tokenizer.render_conversation(conversation)
conv_buffer.append(ids)
ids, mask = tokenizer.render_conversation(conversation)
conv_buffer.append((ids, mask))
cursor += ddp_world_size
if cursor >= dataset_size:
cursor = cursor % dataset_size
@@ -218,9 +218,11 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
while True:
rows = []
mask_rows = []
row_lengths = [] # Track actual content length (excluding padding) for each row
for _ in range(args.device_batch_size):
row = []
mask_row = []
padded = False
while len(row) < row_capacity:
# Ensure buffer has conversations
@@ -232,7 +234,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
# Find largest conversation that fits entirely
best_idx = -1
best_len = 0
for i, conv in enumerate(conv_buffer):
for i, (conv, _) in enumerate(conv_buffer):
conv_len = len(conv)
if conv_len <= remaining and conv_len > best_len:
best_idx = i
@@ -240,14 +242,16 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
if best_idx >= 0:
# Found a conversation that fits - use it entirely
conv = conv_buffer.pop(best_idx)
conv, conv_mask = conv_buffer.pop(best_idx)
row.extend(conv)
mask_row.extend(conv_mask)
consumed += ddp_world_size # Track actual consumption
else:
# No conversation fits - pad the remainder instead of cropping
# This ensures we never discard any tokens
content_len = len(row)
row.extend([bos_token] * remaining) # Pad with BOS tokens
mask_row.extend([0] * remaining)
padded = True
break # Row is now full (with padding)
@@ -257,6 +261,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
else:
row_lengths.append(row_capacity)
rows.append(row[:row_capacity])
mask_rows.append(mask_row[:row_capacity])
# Stopping condition to respect num_iterations, if given
it += 1
@@ -280,6 +285,13 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
# Apply the loss mask from render_conversation (mask=1 for assistant completions,
# mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
# with targets (shifted by 1). Unmasked positions get -1 (ignore_index).
mask_tensor = torch.tensor(mask_rows, dtype=torch.int8)
mask_targets = mask_tensor[:, 1:].to(device=device)
targets[mask_targets == 0] = -1
# Mask out padding positions in targets (set to -1 = ignore_index)
# For each row, positions >= (content_length - 1) in targets should be masked
for i, content_len in enumerate(row_lengths):