mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Merge 16788eed3c into 4a87a0d19f
This commit is contained in:
commit
6398858ea9
|
|
@@ -265,13 +265,14 @@ class GPT(nn.Module):
|
||||||
# training mode: compute and return the loss
|
# training mode: compute and return the loss
|
||||||
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
|
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
|
||||||
logits = self.lm_head(x)
|
logits = self.lm_head(x)
|
||||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
|
||||||
logits = logits.float() # use tf32/fp32 for logits
|
logits = logits.float() # use tf32/fp32 for logits
|
||||||
|
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||||
return loss
|
return loss
|
||||||
else:
|
else:
|
||||||
# inference mode: compute and return the logits
|
# inference mode: compute and return the logits
|
||||||
logits = self.lm_head(x)
|
logits = self.lm_head(x)
|
||||||
|
logits = logits.float() # use tf32/fp32 for logits
|
||||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user