From 3a611f0821246a5737735fe70aef5ef821832f4b Mon Sep 17 00:00:00 2001
From: Anton Chechetka
Date: Sun, 23 Nov 2025 12:51:36 +0100
Subject: [PATCH] Clean up copypaste in GPT.forward()

---
 nanochat/gpt.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 216343c..3b83d01 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -260,19 +260,18 @@ class GPT(nn.Module):
         x = norm(x)
 
         # Forward the lm_head (compute logits)
+        logits = self.lm_head(x)
         softcap = 15
+        logits = softcap * torch.tanh(logits / softcap) # logits softcap
+
         if targets is not None:
             # training mode: compute and return the loss
             # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
-            logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             logits = logits.float() # use tf32/fp32 for logits
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
             return loss
         else:
-            # inference mode: compute and return the logits
-            logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
+            # inference mode: only return the logits
             return logits
 
     @torch.inference_mode()
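
For reference, a minimal standalone sketch of the control flow after this change: the logits are computed and softcapped once, then the function branches on whether targets were supplied. The helper name lm_head_forward and the toy shapes below are illustrative only and do not exist in nanochat; the softcap value and the cross-entropy call mirror the patched code.

    import torch
    import torch.nn.functional as F

    def lm_head_forward(x, lm_head, targets=None, loss_reduction="mean"):
        # Compute and softcap the logits once, shared by both branches.
        logits = lm_head(x)                              # (B, T, vocab_size)
        softcap = 15
        logits = softcap * torch.tanh(logits / softcap)  # squash logits into (-softcap, softcap)
        if targets is not None:
            # training mode: compute and return the loss
            logits = logits.float()                      # use tf32/fp32 for the loss
            return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                                   ignore_index=-1, reduction=loss_reduction)
        # inference mode: only return the logits
        return logits

    # Toy usage: batch of 2, sequence length 4, hidden size 8, vocab of 16.
    lm_head = torch.nn.Linear(8, 16, bias=False)
    x = torch.randn(2, 4, 8)
    targets = torch.randint(0, 16, (2, 4))
    loss = lm_head_forward(x, lm_head, targets)          # scalar training loss
    logits = lm_head_forward(x, lm_head)                 # (2, 4, 16) inference logits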