This commit is contained in:
Anton Chechetka 2025-11-23 12:59:58 +01:00 committed by GitHub
commit bfa37c8723
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -260,19 +260,18 @@ class GPT(nn.Module):
         x = norm(x)
         # Forward the lm_head (compute logits)
+        logits = self.lm_head(x)
         softcap = 15
+        logits = softcap * torch.tanh(logits / softcap) # logits softcap
         if targets is not None:
             # training mode: compute and return the loss
             # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
-            logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
             logits = logits.float() # use tf32/fp32 for logits
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
             return loss
         else:
-            # inference mode: compute and return the logits
-            logits = self.lm_head(x)
-            logits = softcap * torch.tanh(logits / softcap) # logits softcap
+            # inference mode: only return the logits
             return logits

     @torch.inference_mode()