apply float32 cast before logits softcapping so the tanh is in fp32. torch compile fuses this correctly with no extra memory costs.

This commit is contained in:
Andrej 2025-12-08 14:17:43 -08:00 committed by GitHub
commit cbf30c842c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -265,13 +265,14 @@ class GPT(nn.Module):
# training mode: compute and return the loss
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
logits = softcap * torch.tanh(logits / softcap) # logits softcap
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
else:
# inference mode: compute and return the logits
logits = self.lm_head(x)
logits = logits.float() # use tf32/fp32 for logits
logits = softcap * torch.tanh(logits / softcap) # logits softcap
return logits