From bffdb2ef919b21cf576b0d55fdec725b94f89f85 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 9 Dec 2025 02:01:05 +0000 Subject: [PATCH] group common code to make things neater in gpt logit computation --- nanochat/gpt.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 52f7075..68a5fed 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -260,20 +260,18 @@ class GPT(nn.Module): x = norm(x) # Forward the lm_head (compute logits) - softcap = 15 + softcap = 15 # smoothly cap the logits to the range [-softcap, softcap] + logits = self.lm_head(x) # (B, T, vocab_size) <- very big tensor, large amount of memory + logits = logits.float() # switch to fp32 for logit softcap and loss computation + logits = softcap * torch.tanh(logits / softcap) # squash the logits + if targets is not None: - # training mode: compute and return the loss - # TODO: experiment with Liger Kernels / chunked cross-entropy etc. - logits = self.lm_head(x) - logits = logits.float() # use tf32/fp32 for logits - logits = softcap * torch.tanh(logits / softcap) # logits softcap + # training: given the targets, compute and return the loss + # TODO experiment with chunked cross-entropy? loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction) return loss else: - # inference mode: compute and return the logits - logits = self.lm_head(x) - logits = logits.float() # use tf32/fp32 for logits - logits = softcap * torch.tanh(logits / softcap) # logits softcap + # inference: just return the logits directly return logits @torch.inference_mode()