diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 5487047f..208acd14 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -416,7 +416,7 @@ class GPT(nn.Module): if targets is not None: # training: given the targets, compute and return the loss # TODO experiment with chunked cross-entropy? - loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1, reduction=loss_reduction) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction) return loss else: # inference: just return the logits directly diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index 6f59b110..5a556e6c 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -32,7 +32,7 @@ def evaluate_bpb(model, batches, steps, token_bytes): x, y = next(batch_iter) loss2d = model(x, y, loss_reduction='none') # (B, T) loss2d = loss2d.view(-1) # flatten - y = y.reshape(-1) # flatten + y = y.view(-1) # flatten if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32 # slightly more complex code path if some target tokens are ignore_index (e.g. -1) # any target token < 0 is to be ignored: do NOT index token_bytes with negatives diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index a783ed21..aad23ce2 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -277,8 +277,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): # Build tensors use_cuda = device_type == "cuda" batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda) - inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda) - targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda) + inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous() + targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous() # Mask out padding positions in targets (set to -1 = ignore_index) # For each row, positions >= (content_length - 1) in targets should be masked