diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index a783ed21..aad23ce2 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -277,8 +277,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): # Build tensors use_cuda = device_type == "cuda" batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda) - inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda) - targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda) + inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous() + targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous() # Mask out padding positions in targets (set to -1 = ignore_index) # For each row, positions >= (content_length - 1) in targets should be masked