formatting

This commit is contained in:
Matěj Kripner 2025-12-09 12:48:46 +01:00
parent bbc57da7d5
commit d314e96aa2

View File

@ -265,8 +265,7 @@ class GPT(nn.Module):
# Forward the lm_head (compute logits)
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
# slice to remove padding
logits = logits[..., :self.config.vocab_size]
logits = logits[..., :self.config.vocab_size] # slice to remove padding
logits = logits.float() # switch to fp32 for logit softcap and loss computation
logits = softcap * torch.tanh(logits / softcap) # squash the logits