mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-18 12:09:09 +00:00
fix directly in data loader instead
This commit is contained in:
parent
83de1b18b1
commit
da507c5835
|
|
@ -416,7 +416,7 @@ class GPT(nn.Module):
|
|||
if targets is not None:
|
||||
# training: given the targets, compute and return the loss
|
||||
# TODO experiment with chunked cross-entropy?
|
||||
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
return loss
|
||||
else:
|
||||
# inference: just return the logits directly
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
x, y = next(batch_iter)
|
||||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
y = y.reshape(-1) # flatten
|
||||
y = y.view(-1) # flatten
|
||||
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
|
||||
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
|
||||
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
|
||||
|
|
|
|||
|
|
@ -277,8 +277,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
# Build tensors
|
||||
use_cuda = device_type == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous()
|
||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous()
|
||||
|
||||
# Mask out padding positions in targets (set to -1 = ignore_index)
|
||||
# For each row, positions >= (content_length - 1) in targets should be masked
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user