This commit is contained in:
Wooram Son 2026-03-24 16:23:35 -04:00 committed by GitHub
commit eca9e40d43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 0 deletions

View File

@ -469,6 +469,8 @@ class GPT(nn.Module):
if targets is not None:
# training: given the targets, compute and return the loss
# TODO experiment with chunked cross-entropy?
if loss_reduction == 'mean' and not (targets != -1).any():
return logits.sum() * 0.0
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
else:

61
tests/test_gpt_loss.py Normal file
View File

@ -0,0 +1,61 @@
import torch
import torch.nn.functional as F
from nanochat.gpt import GPT, GPTConfig
def build_test_model():
config = GPTConfig(
sequence_len=8,
vocab_size=32,
n_layer=2,
n_head=2,
n_kv_head=2,
n_embd=32,
window_pattern="L",
)
model = GPT(config)
model.init_weights()
return model
def test_forward_mean_loss_returns_graph_connected_zero_when_all_targets_ignored():
torch.manual_seed(0)
model = build_test_model()
idx = torch.randint(0, model.config.vocab_size, (2, 8))
targets = torch.full((2, 8), -1, dtype=torch.long)
loss = model(idx, targets)
assert loss.requires_grad
assert torch.isfinite(loss)
assert loss.item() == 0.0
loss.backward()
grads = [param.grad for param in model.parameters() if param.requires_grad]
assert grads, "expected trainable parameters"
assert all(grad is not None for grad in grads)
assert all(torch.count_nonzero(grad) == 0 for grad in grads)
def test_forward_mean_loss_matches_cross_entropy_on_non_ignored_targets():
torch.manual_seed(1)
model = build_test_model()
idx = torch.randint(0, model.config.vocab_size, (2, 8))
targets = torch.tensor(
[
[-1, -1, 3, 4, -1, 5, 6, -1],
[-1, 7, -1, -1, 8, 9, -1, 10],
],
dtype=torch.long,
)
logits = model(idx)
loss = model(idx, targets)
valid = targets != -1
expected = F.cross_entropy(logits[valid], targets[valid], reduction="mean")
assert torch.isfinite(loss)
assert torch.allclose(loss, expected)