From 83dccc20aeab357ae4bb30d9c2ce938763efa929 Mon Sep 17 00:00:00 2001 From: Anish <12670807+gpu-poor@users.noreply.github.com> Date: Tue, 3 Mar 2026 06:07:47 +0530 Subject: [PATCH] Restore completion-only loss masking in SFT dataloader (#582) * printing steps count * adding reply only loss for chat * using the mask by render_conversation function of tokeniser * undoing some changes * putting back the comment which got removed accidently, no functionality change --- scripts/chat_sft.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index a783ed2..f31a2d3 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -197,7 +197,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): row_capacity = args.max_seq_len + 1 # +1 for target at last position bos_token = tokenizer.get_bos_token_id() - # Conversation buffer: list of token lists + # Conversation buffer: list of (token_ids, loss_mask) tuples conv_buffer = [] cursor = ddp_rank # Each rank processes different conversations (for fetching) consumed = ddp_rank # Track actual consumption separately from buffering @@ -208,8 +208,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): nonlocal cursor, epoch while len(conv_buffer) < buffer_size: conversation = dataset[cursor] - ids, _ = tokenizer.render_conversation(conversation) - conv_buffer.append(ids) + ids, mask = tokenizer.render_conversation(conversation) + conv_buffer.append((ids, mask)) cursor += ddp_world_size if cursor >= dataset_size: cursor = cursor % dataset_size @@ -218,9 +218,11 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): while True: rows = [] + mask_rows = [] row_lengths = [] # Track actual content length (excluding padding) for each row for _ in range(args.device_batch_size): row = [] + mask_row = [] padded = False while len(row) < row_capacity: # Ensure buffer has conversations @@ -232,7 +234,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): # Find largest conversation that fits entirely best_idx = -1 best_len = 0 - for i, conv in enumerate(conv_buffer): + for i, (conv, _) in enumerate(conv_buffer): conv_len = len(conv) if conv_len <= remaining and conv_len > best_len: best_idx = i @@ -240,14 +242,16 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): if best_idx >= 0: # Found a conversation that fits - use it entirely - conv = conv_buffer.pop(best_idx) + conv, conv_mask = conv_buffer.pop(best_idx) row.extend(conv) + mask_row.extend(conv_mask) consumed += ddp_world_size # Track actual consumption else: # No conversation fits - pad the remainder instead of cropping # This ensures we never discard any tokens content_len = len(row) row.extend([bos_token] * remaining) # Pad with BOS tokens + mask_row.extend([0] * remaining) padded = True break # Row is now full (with padding) @@ -257,6 +261,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): else: row_lengths.append(row_capacity) rows.append(row[:row_capacity]) + mask_rows.append(mask_row[:row_capacity]) # Stopping condition to respect num_iterations, if given it += 1 @@ -280,6 +285,13 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda) targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda) + # Apply the loss mask from render_conversation (mask=1 for assistant completions, + # mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns + # with targets (shifted by 1). Unmasked positions get -1 (ignore_index). + mask_tensor = torch.tensor(mask_rows, dtype=torch.int8) + mask_targets = mask_tensor[:, 1:].to(device=device) + targets[mask_targets == 0] = -1 + # Mask out padding positions in targets (set to -1 = ignore_index) # For each row, positions >= (content_length - 1) in targets should be masked for i, content_len in enumerate(row_lengths):