diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 216343c..3b83d01 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -260,19 +260,18 @@ class GPT(nn.Module):
         x = norm(x)
 
         # Forward the lm_head (compute logits)
+        logits = self.lm_head(x)
         softcap = 15
+        logits = softcap * torch.tanh(logits / softcap) # logits softcap
+
         if targets is not None:
             # training mode: compute and return the loss
             # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
-            logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             logits = logits.float() # use tf32/fp32 for logits
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
             return loss
         else:
-            # inference mode: compute and return the logits
-            logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
+            # inference mode: only return the logits
             return logits
 
     @torch.inference_mode()
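
For reference, the softcap that this change hoists above the if/else is the standard tanh soft clamp: near-identity for small logits, smoothly bounding large ones to (-softcap, softcap). A minimal standalone sketch of that behavior (assumes only torch; the function name and example values are illustrative, not from nanochat):

    import torch

    def softcap_logits(logits: torch.Tensor, softcap: float = 15.0) -> torch.Tensor:
        # tanh maps to (-1, 1); scaling by softcap bounds the result to
        # (-softcap, softcap) while staying ~identity for |logits| << softcap
        return softcap * torch.tanh(logits / softcap)

    x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
    print(softcap_logits(x))  # small values pass through almost unchanged; large ones cap near +/-15

Because the clamp is applied identically in both branches, computing it once before the branch does not change the model's outputs; it only removes the duplicated lines.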