mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-17 22:38:32 +00:00
adding reply only loss for chat
This commit is contained in:
parent
daf7ec9156
commit
6ddd0602ed
|
|
@ -67,6 +67,8 @@ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max pro
|
|||
# Data mixture
|
||||
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
|
||||
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
|
||||
parser.add_argument("--mask-user-prompts", action="store_true", help="mask user prompts in the target sequence")
|
||||
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -279,6 +281,15 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
|
||||
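# Mask user prompts if requested, so the loss is computed only on the non-user (reply) tokens
# NOTE(review): the two special-token ids below are hard-coded literals; confirm they match the tokenizer in use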
|
||||
if args.mask_user_prompts:
|
||||
user_start_id = 32760
|
||||
assistant_start_id = 32762
|
||||
starts = (targets == user_start_id).int()
|
||||
ends = (targets == assistant_start_id).int()
|
||||
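# mask_span is True from each user_start token through the next assistant_start token (both ends inclusive);
# this assumes user_start / assistant_start strictly alternate within a row, beginning with user_start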
|
||||
mask_span = (torch.cumsum(starts, dim=1) - torch.cumsum(ends, dim=1) + ends) > 0
|
||||
targets[mask_span] = -1
|
||||
|
||||
# Mask out padding positions in targets (set to -1 = ignore_index)
|
||||
# For each row, positions >= (content_length - 1) in targets should be masked
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user