diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 5a066b2..db36bdb 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -151,12 +151,111 @@ class Block(nn.Module):
         return x
 
 
+class FlatRollEmbed(nn.Module):
+    """
+    Frozen embedding matrix W in R^{V x D}, built for epistemic purposes:
+    row r is a flat-spectrum base vector x circularly rolled by (r % D),
+    plus a sparse "tower" S that places one aligned positive extremum per row.
+    """
+
+    def __init__(self, config, scale: str = "box", seed: int = 0,
+                 freeze: bool = True, dtype=None, device=None):
+        super().__init__()
+        V = int(config.vocab_size)
+        D = int(config.n_embd)
+        dtype = dtype or torch.float32
+        eps = 1e-12
+
+        # base vector x ∈ R^D with flat spectrum, DC=0
+        x = self._make_base(D, scale=scale, seed=seed, dtype=dtype, device=device)  # [D]
+
+        # circulant-like rows: row r is x rolled by (r % D)
+        # (explicit Python loop for clarity; can be vectorized if needed)
+        shifts = torch.arange(V, device=device)
+        rows = [torch.roll(x, shifts=int(s.item() % D), dims=0) for s in shifts]
+        W = torch.stack(rows, dim=0).to(dtype)
+
+        # align a single positive extremum via a "tower" S
+        M = int(torch.argmax(x))   # index of max in x
+        pm = x[M].item()
+        N = 1.0 / (pm + eps)       # safe reciprocal
+
+        # S[r, (r + M) % D] = N
+        r_idx = torch.arange(V, device=device)
+        c_idx = (r_idx + M) % D
+        S = torch.zeros((V, D), dtype=dtype, device=device)
+        S[r_idx, c_idx] = N
+
+        mixed = W + S
+        self.embed = nn.Embedding.from_pretrained(mixed, freeze=freeze)
+
+    @staticmethod
+    def _make_base(D: int, scale: str = "box", seed: int = 0,
+                   dtype=torch.float32, device=None) -> torch.Tensor:
+        """
+        Build x ∈ R^D where |FFT(x)| is flat for k=1..D-1 and DC=0.
+
+        scale:
+          - "unit": ||x||_2 = 1
+          - "box":  max|x_i| = 1
+        """
+        # Build on CPU, then move to device at the end.
+        # Use complex64 for float/bfloat/half; complex128 otherwise.
+        if dtype in (torch.float16, torch.bfloat16, torch.float32):
+            complex_dtype = torch.complex64
+            work_float = torch.float32
+        else:
+            complex_dtype = torch.complex128
+            work_float = torch.float64
+
+        # seeded generator so the base vector is reproducible
+        gen = torch.Generator().manual_seed(seed)
+        X = torch.zeros(D, dtype=complex_dtype)
+
+        # DC bin = 0 (already zero from torch.zeros; kept explicit)
+        X[0] = 0
+
+        if D % 2 == 0:
+            # bins 1..D/2-1 are complex-conjugate pairs; Nyquist bin must be real
+            for k in range(1, D // 2):
+                phi = torch.rand((), generator=gen, dtype=work_float) * (2 * math.pi)
+                val = (torch.cos(phi) + 1j * torch.sin(phi)).to(complex_dtype)
+                X[k] = val
+                X[D - k] = torch.conj(val)
+            # Nyquist bin (real, ±1)
+            X[D // 2] = 1.0 if torch.rand((), generator=gen) < 0.5 else -1.0
+        else:
+            for k in range(1, (D - 1) // 2 + 1):
+                phi = torch.rand((), generator=gen, dtype=work_float) * (2 * math.pi)
+                val = (torch.cos(phi) + 1j * torch.sin(phi)).to(complex_dtype)
+                X[k] = val
+                X[D - k] = torch.conj(val)
+
+        x = torch.fft.ifft(X).real  # real length-D base vector (float32/64)
+        x = x.to(work_float)
+
+        if scale == "unit":
+            x = x / (x.norm() + 1e-12)
+        elif scale == "box":
+            x = x / (x.abs().max() + 1e-12)
+        else:
+            raise ValueError("scale must be 'unit' or 'box'")
+
+        x = x.to(dtype)
+        if device is not None:
+            x = x.to(device)
+        return x
+
+    def forward(self, input_ids: torch.LongTensor):
+        # (batch, seq_len, D)
+        return self.embed(input_ids)
+
+
 class GPT(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.transformer = nn.ModuleDict({
-            "wte": nn.Embedding(config.vocab_size, config.n_embd),
+            "wte": FlatRollEmbed(config) if config.frozen_embed else nn.Embedding(config.vocab_size, config.n_embd),
             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
         })
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
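
A quick way to sanity-check the construction outside of training (a minimal sketch, not part of the patch; it assumes FlatRollEmbed only reads vocab_size and n_embd from the config, and that the new frozen_embed flag is added to GPTConfig elsewhere; the tolerances are illustrative):

    import torch
    from types import SimpleNamespace
    from nanochat.gpt import FlatRollEmbed

    cfg = SimpleNamespace(vocab_size=32, n_embd=16)
    emb = FlatRollEmbed(cfg, scale="box", seed=0)
    W = emb.embed.weight               # [V, D], frozen by default
    assert not W.requires_grad

    # circulant structure: every row equals W[0] rolled by (r % D),
    # because the "tower" spike rolls along with the base vector
    D = cfg.n_embd
    for r in range(cfg.vocab_size):
        assert torch.allclose(W[r], torch.roll(W[0], r % D))

    # flat-spectrum property of the base vector itself:
    # DC bin ~ 0, all remaining FFT magnitudes (nearly) identical
    x = FlatRollEmbed._make_base(D, scale="box", seed=0)
    mag = torch.fft.fft(x).abs()
    assert mag[0].item() < 1e-4
    assert (mag[1:].max() - mag[1:].min()).item() < 1e-3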