diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 0b822e41..07a1eae8 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -238,6 +238,11 @@ class GPT(nn.Module): for i in range(n_layer): self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1)) + # Smear/backout scalars and smear gate must be explicitly initialized + torch.nn.init.zeros_(self.smear_lambda) + torch.nn.init.constant_(self.backout_lambda, 0.2) + torch.nn.init.uniform_(self.smear_gate.weight, 0.0, 0.02) + # Value embeddings (init like c_v: uniform with same std) for ve in self.value_embeds.values(): torch.nn.init.uniform_(ve.weight, -s, s)