From 16788eed3cc3a79a94fa5bb852db722f79852cb7 Mon Sep 17 00:00:00 2001
From: spjosyula
Date: Sun, 23 Nov 2025 20:12:09 +0530
Subject: [PATCH] fix(model): apply float32 cast before logits softcapping

This change ensures that the logits softcapping operation (tanh) is
performed in float32 precision rather than bfloat16. Previously, the
code cast to float32 after the tanh operation, which meant the
non-linearity was computed with bfloat16 precision. The same float32
cast is also added before softcapping in the inference path.
---
 nanochat/gpt.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 216343c..52f7075 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -265,13 +265,14 @@ class GPT(nn.Module):
             # training mode: compute and return the loss
             # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
             logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             logits = logits.float() # use tf32/fp32 for logits
+            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
             return loss
         else:
             # inference mode: compute and return the logits
             logits = self.lm_head(x)
+            logits = logits.float() # use tf32/fp32 for logits
             logits = softcap * torch.tanh(logits / softcap) # logits softcap
             return logits
 
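
Reviewer note (not part of the patch): below is a minimal, self-contained
sketch of the numerical effect the reordering targets, comparing the softcap
non-linearity computed in bfloat16 versus float32. The softcap value of 15.0
and the tensor shape are illustrative assumptions, not values taken from
nanochat.

import torch

torch.manual_seed(0)
softcap = 15.0  # illustrative value, not necessarily the one used in nanochat
logits = (torch.randn(4, 8) * 20.0).to(torch.bfloat16)  # stand-in for bf16 lm_head output

# old ordering: tanh computed in bfloat16, result cast to float32 afterwards
old = (softcap * torch.tanh(logits / softcap)).float()

# new ordering: cast to float32 first, then apply the tanh softcap
new = softcap * torch.tanh(logits.float() / softcap)

print("max abs difference between orderings:", (old - new).abs().max().item())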