fix(model): apply float32 cast before logits softcapping

This change ensures that the logits softcapping operation (tanh) is performed in float32 precision rather than bfloat16. Previously, the code cast to float32 only after the tanh, so the non-linearity itself was computed in bfloat16. The inference path also gains a float32 cast before the softcap, which it previously lacked.
Author: spjosyula
Date:   2025-11-23 20:12:09 +05:30
parent 4a87a0d19f
commit 16788eed3c


@@ -265,13 +265,14 @@ class GPT(nn.Module):
             # training mode: compute and return the loss
             # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
             logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             logits = logits.float() # use tf32/fp32 for logits
+            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
             return loss
         else:
             # inference mode: compute and return the logits
             logits = self.lm_head(x)
+            logits = logits.float() # use tf32/fp32 for logits
             logits = softcap * torch.tanh(logits / softcap) # logits softcap
             return logits
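
For intuition, here is a minimal standalone sketch (not part of the repository) comparing the two orderings. The softcap value of 15.0 and the vocabulary size are assumed purely for illustration, not taken from this codebase:

import torch

# Assumed illustrative values: softcap=15.0 and a ~50k-entry vocabulary
# are placeholders that just make the example concrete.
softcap = 15.0
logits_bf16 = torch.randn(4, 50304, dtype=torch.bfloat16)

# Old order: tanh computed in bfloat16, float32 cast applied afterwards.
old = (softcap * torch.tanh(logits_bf16 / softcap)).float()

# New order (this commit): cast to float32 first, then apply the softcap,
# so the tanh non-linearity runs in full precision.
new = softcap * torch.tanh(logits_bf16.float() / softcap)

# The two results differ because bfloat16 carries only about 8 bits of
# significand; the gap is the precision recovered by casting before the tanh.
print((old - new).abs().max())

Since the cross-entropy loss is computed from these logits, performing the softcap in float32 removes one source of low-precision noise from the loss without changing the memory cost of the matmul itself, which still runs in the autocast dtype.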