This commit is contained in:
Sofie Van Landeghem 2026-02-27 02:12:52 +01:00 committed by GitHub
commit 61fc8df11f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -277,8 +277,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
# Build tensors
use_cuda = device_type == "cuda"
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous()
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous()
# Mask out padding positions in targets (set to -1 = ignore_index)
# For each row, positions >= (content_length - 1) in targets should be masked