diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 52f7075..68a5fed 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -260,20 +260,18 @@ class GPT(nn.Module): x = norm(x) # Forward the lm_head (compute logits) - softcap = 15 + softcap = 15 # smoothly cap the logits to the range [-softcap, softcap] + logits = self.lm_head(x) # (B, T, vocab_size) <- very big tensor, large amount of memory + logits = logits.float() # switch to fp32 for logit softcap and loss computation + logits = softcap * torch.tanh(logits / softcap) # squash the logits + if targets is not None: - # training mode: compute and return the loss - # TODO: experiment with Liger Kernels / chunked cross-entropy etc. - logits = self.lm_head(x) - logits = logits.float() # use tf32/fp32 for logits - logits = softcap * torch.tanh(logits / softcap) # logits softcap + # training: given the targets, compute and return the loss + # TODO experiment with chunked cross-entropy? loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction) return loss else: - # inference mode: compute and return the logits - logits = self.lm_head(x) - logits = logits.float() # use tf32/fp32 for logits - logits = softcap * torch.tanh(logits / softcap) # logits softcap + # inference: just return the logits directly return logits @torch.inference_mode()