From daf7ec9156813c651f2095b835d7982ddcfda8aa Mon Sep 17 00:00:00 2001 From: gpu-poor Date: Thu, 26 Feb 2026 10:07:09 +0000 Subject: [PATCH 1/5] printing steps count --- nanochat/loss_eval.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index 5a556e6..c983601 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -28,7 +28,8 @@ def evaluate_bpb(model, batches, steps, token_bytes): total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device()) total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device()) batch_iter = iter(batches) - for _ in range(steps): + for step in range(steps): + print(f"\reval {step+1}/{steps}", end="", flush=True) x, y = next(batch_iter) loss2d = model(x, y, loss_reduction='none') # (B, T) loss2d = loss2d.view(-1) # flatten @@ -51,6 +52,7 @@ def evaluate_bpb(model, batches, steps, token_bytes): num_bytes2d = token_bytes[y] total_nats += (loss2d * (num_bytes2d > 0)).sum() total_bytes += num_bytes2d.sum() + print() # newline after progress # sum reduce across all ranks world_size = dist.get_world_size() if dist.is_initialized() else 1 if world_size > 1: From 6ddd0602edd61ac1e0e9a1a048453e266b46f414 Mon Sep 17 00:00:00 2001 From: gpu-poor Date: Fri, 27 Feb 2026 01:41:14 +0530 Subject: [PATCH 2/5] adding reply only loss for chat --- scripts/chat_sft.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index a783ed2..28df1fd 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -67,6 +67,8 @@ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max pro # Data mixture parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") 
+parser.add_argument("--mask-user-prompts", action="store_true", help="mask user prompts in the target sequence") + args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -279,6 +281,15 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda) inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda) targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda) + # Mask user prompts if requested + if args.mask_user_prompts: + user_start_id = 32760 + assistant_start_id = 32762 + starts = (targets == user_start_id).int() + ends = (targets == assistant_start_id).int() + # mask is 1 between user_start and assistant_start (inclusive of assistant_start) + mask_span = (torch.cumsum(starts, dim=1) - torch.cumsum(ends, dim=1) + ends) > 0 + targets[mask_span] = -1 # Mask out padding positions in targets (set to -1 = ignore_index) # For each row, positions >= (content_length - 1) in targets should be masked From 5a06a7c597a305c52c22027491756d7dbea5155d Mon Sep 17 00:00:00 2001 From: gpu-poor Date: Sat, 28 Feb 2026 16:36:52 +0530 Subject: [PATCH 3/5] using the mask by render_conversation function of tokeniser --- scripts/chat_sft.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 28df1fd..f91822d 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -67,8 +67,6 @@ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max pro # Data mixture parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") 
-parser.add_argument("--mask-user-prompts", action="store_true", help="mask user prompts in the target sequence") - args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -199,7 +197,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): row_capacity = args.max_seq_len + 1 # +1 for target at last position bos_token = tokenizer.get_bos_token_id() - # Conversation buffer: list of token lists + # Conversation buffer: list of (token_ids, loss_mask) tuples conv_buffer = [] cursor = ddp_rank # Each rank processes different conversations (for fetching) consumed = ddp_rank # Track actual consumption separately from buffering @@ -210,8 +208,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): nonlocal cursor, epoch while len(conv_buffer) < buffer_size: conversation = dataset[cursor] - ids, _ = tokenizer.render_conversation(conversation) - conv_buffer.append(ids) + ids, mask = tokenizer.render_conversation(conversation) + conv_buffer.append((ids, mask)) cursor += ddp_world_size if cursor >= dataset_size: cursor = cursor % dataset_size @@ -220,9 +218,11 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): while True: rows = [] + mask_rows = [] row_lengths = [] # Track actual content length (excluding padding) for each row for _ in range(args.device_batch_size): row = [] + mask_row = [] padded = False while len(row) < row_capacity: # Ensure buffer has conversations @@ -234,7 +234,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): # Find largest conversation that fits entirely best_idx = -1 best_len = 0 - for i, conv in enumerate(conv_buffer): + for i, (conv, _) in enumerate(conv_buffer): conv_len = len(conv) if conv_len <= remaining and conv_len > best_len: best_idx = i @@ -242,14 +242,16 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): if best_idx >= 0: # Found a conversation that fits - use it entirely - conv = 
conv_buffer.pop(best_idx) + conv, conv_mask = conv_buffer.pop(best_idx) row.extend(conv) + mask_row.extend(conv_mask) consumed += ddp_world_size # Track actual consumption else: # No conversation fits - pad the remainder instead of cropping # This ensures we never discard any tokens content_len = len(row) row.extend([bos_token] * remaining) # Pad with BOS tokens + mask_row.extend([0] * remaining) padded = True break # Row is now full (with padding) @@ -259,6 +261,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): else: row_lengths.append(row_capacity) rows.append(row[:row_capacity]) + mask_rows.append(mask_row[:row_capacity]) # Stopping condition to respect num_iterations, if given it += 1 @@ -281,18 +284,15 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda) inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda) targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda) - # Mask user prompts if requested - if args.mask_user_prompts: - user_start_id = 32760 - assistant_start_id = 32762 - starts = (targets == user_start_id).int() - ends = (targets == assistant_start_id).int() - # mask is 1 between user_start and assistant_start (inclusive of assistant_start) - mask_span = (torch.cumsum(starts, dim=1) - torch.cumsum(ends, dim=1) + ends) > 0 - targets[mask_span] = -1 + + # Apply the loss mask from render_conversation (mask=1 for assistant completions, + # mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns + # with targets (shifted by 1). Unmasked positions get -1 (ignore_index). 
+ mask_tensor = torch.tensor(mask_rows, dtype=torch.int8) + mask_targets = mask_tensor[:, 1:].to(device=device) + targets[mask_targets == 0] = -1 # Mask out padding positions in targets (set to -1 = ignore_index) - # For each row, positions >= (content_length - 1) in targets should be masked for i, content_len in enumerate(row_lengths): if content_len < row_capacity: targets[i, content_len-1:] = -1 From 014084a65602da747436afe425bb727f61986206 Mon Sep 17 00:00:00 2001 From: gpu-poor Date: Sun, 1 Mar 2026 13:16:18 +0530 Subject: [PATCH 4/5] undoing some changes --- nanochat/loss_eval.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index c983601..5a556e6 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -28,8 +28,7 @@ def evaluate_bpb(model, batches, steps, token_bytes): total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device()) total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device()) batch_iter = iter(batches) - for step in range(steps): - print(f"\reval {step+1}/{steps}", end="", flush=True) + for _ in range(steps): x, y = next(batch_iter) loss2d = model(x, y, loss_reduction='none') # (B, T) loss2d = loss2d.view(-1) # flatten @@ -52,7 +51,6 @@ def evaluate_bpb(model, batches, steps, token_bytes): num_bytes2d = token_bytes[y] total_nats += (loss2d * (num_bytes2d > 0)).sum() total_bytes += num_bytes2d.sum() - print() # newline after progress # sum reduce across all ranks world_size = dist.get_world_size() if dist.is_initialized() else 1 if world_size > 1: From 96b6d648951fa7023fcbc701e3856ab96c2a1a61 Mon Sep 17 00:00:00 2001 From: gpu-poor Date: Tue, 3 Mar 2026 00:21:16 +0530 Subject: [PATCH 5/5] putting back the comment which got removed accidentally, no functionality change --- scripts/chat_sft.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index f91822d..f31a2d3 100644 --- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py @@ -293,6 +293,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): targets[mask_targets == 0] = -1 # Mask out padding positions in targets (set to -1 = ignore_index) + # For each row, positions >= (content_length - 1) in targets should be masked for i, content_len in enumerate(row_lengths): if content_len < row_capacity: targets[i, content_len-1:] = -1