diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 216343c..52f7075 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -265,13 +265,14 @@ class GPT(nn.Module):
             # training mode: compute and return the loss
             # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
             logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             logits = logits.float() # use tf32/fp32 for logits
+            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
             return loss
         else:
             # inference mode: compute and return the logits
             logits = self.lm_head(x)
+            logits = logits.float() # use tf32/fp32 for logits
             logits = softcap * torch.tanh(logits / softcap) # logits softcap
             return logits