adding reply only loss for chat

This commit is contained in:
gpu-poor 2026-02-27 01:41:14 +05:30
parent daf7ec9156
commit 6ddd0602ed

View File

@ -67,6 +67,8 @@ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max pro
# Data mixture
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
parser.add_argument("--mask-user-prompts", action="store_true", help="mask user prompts in the target sequence")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@ -279,6 +281,15 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
# Mask user prompts if requested
if args.mask_user_prompts:
user_start_id = 32760
assistant_start_id = 32762
starts = (targets == user_start_id).int()
ends = (targets == assistant_start_id).int()
# mask is 1 between user_start and assistant_start (inclusive of assistant_start)
mask_span = (torch.cumsum(starts, dim=1) - torch.cumsum(ends, dim=1) + ends) > 0
targets[mask_span] = -1
# Mask out padding positions in targets (set to -1 = ignore_index)
# For each row, positions >= (content_length - 1) in targets should be masked