mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-18 14:58:34 +00:00
Merge da507c5835 into 83dccc20ae
This commit is contained in:
commit
3a998fccf5
|
|
@ -282,8 +282,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
||||||
# Build tensors
|
# Build tensors
|
||||||
use_cuda = device_type == "cuda"
|
use_cuda = device_type == "cuda"
|
||||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
|
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous()
|
||||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
|
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous()
|
||||||
|
|
||||||
# Apply the loss mask from render_conversation (mask=1 for assistant completions,
|
# Apply the loss mask from render_conversation (mask=1 for assistant completions,
|
||||||
# mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
|
# mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user