diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 0b822e41..07a1eae8 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -238,6 +238,11 @@ class GPT(nn.Module): for i in range(n_layer): self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1)) + # Smear/backout scalars and smear gate must be explicitly initialized + torch.nn.init.zeros_(self.smear_lambda) + torch.nn.init.constant_(self.backout_lambda, 0.2) + torch.nn.init.uniform_(self.smear_gate.weight, 0.0, 0.02) + # Value embeddings (init like c_v: uniform with same std) for ve in self.value_embeds.values(): torch.nn.init.uniform_(ve.weight, -s, s) diff --git a/pyproject.toml b/pyproject.toml index a6e2cca6..0527369f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ explicit = true [project.optional-dependencies] cpu = [ + "setuptools>=65.0.0", "torch==2.9.1", ] gpu = [ diff --git a/uv.lock b/uv.lock index 94558149..c81d3303 100644 --- a/uv.lock +++ b/uv.lock @@ -1507,6 +1507,7 @@ dependencies = [ [package.optional-dependencies] cpu = [ + { name = "setuptools" }, { name = "torch", version = "2.9.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] @@ -1530,6 +1531,7 @@ requires-dist = [ { name = "kernels", specifier = ">=0.11.7" }, { name = "psutil", specifier = ">=7.1.0" }, { name = "rustbpe", specifier = ">=0.1.0" }, + { name = "setuptools", marker = "extra == 'cpu'", specifier = ">=65.0.0" }, { name = "tiktoken", specifier = ">=0.11.0" }, { name = "tokenizers", specifier = ">=0.22.0" }, { name = "torch", specifier = "==2.9.1" },