mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)
fix(model): apply float32 cast before logits softcapping
This change ensures that the logits softcapping operation (tanh) is performed in float32 precision rather than bfloat16. Previously, the code cast to float32 after the tanh operation, which meant the non-linearity was computed in bfloat16 precision. The same float32 cast is also added to the inference path before softcapping.
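A minimal sketch (not part of the commit; the shapes, vocab size, and softcap value below are assumptions for the demo) showing why the order matters: when the tanh runs in bfloat16, the result differs from the full-precision version by bfloat16 rounding error.

import torch

softcap = 15.0  # assumed value, just for the demo
logits_bf16 = torch.randn(8, 32000, dtype=torch.bfloat16)  # stand-in for lm_head output

# old order: softcap computed in bfloat16, cast to float32 afterwards
old = (softcap * torch.tanh(logits_bf16 / softcap)).float()

# new order: cast to float32 first, so tanh runs in full precision
new = softcap * torch.tanh(logits_bf16.float() / softcap)

print((old - new).abs().max())  # typically nonzero: the bfloat16 tanh loses precision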
This commit is contained in:
parent
4a87a0d19f
commit
16788eed3c
@@ -265,13 +265,14 @@ class GPT(nn.Module):
             # training mode: compute and return the loss
             # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
             logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             logits = logits.float() # use tf32/fp32 for logits
+            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
             return loss
         else:
             # inference mode: compute and return the logits
             logits = self.lm_head(x)
+            logits = logits.float() # use tf32/fp32 for logits
             logits = softcap * torch.tanh(logits / softcap) # logits softcap
             return logits
 