Update gpt.py

Add an optimal anchored fixed embedding
azuhanel 2025-10-13 23:12:33 -04:00 committed by GitHub
parent dd6ff9a1cc
commit 1240e1299e

@@ -151,12 +151,111 @@ class Block(nn.Module):
        return x

class FlatRollEmbed(nn.Module):
"""
Embedding matrix W in R^{V x D}
Optimized for epistemic purposes
"""
    def __init__(self, config, scale: str = "box", seed: int = 0,
                 freeze: bool = True, dtype=None, device=None):
        super().__init__()
        V = int(config.vocab_size)
        D = int(config.n_embd)
        dtype = dtype or torch.float32
        eps = 1e-12

        # base vector x ∈ R^D with flat spectrum and zero DC component
        x = self._make_base(D, scale=scale, seed=seed, dtype=dtype, device=device)  # [D]

        # circulant-like rows: row r is x rolled by (r % D)
        # (plain Python loop for clarity; can be vectorized if V is large)
        rows = [torch.roll(x, shifts=r % D, dims=0) for r in range(V)]
        W = torch.stack(rows, dim=0)

        # anchor a single positive extremum per row via a sparse "tower" S:
        # S[r, (r + M) % D] = 1 / x[M], added exactly where row r already
        # attains its maximum, raising it to ≈ x[M] + 1/x[M]
        M = int(torch.argmax(x))       # index of the maximum of x
        pm = x[M].item()
        N = 1.0 / (pm + eps)           # safe reciprocal
        r_idx = torch.arange(V, device=device)
        c_idx = (r_idx + M) % D
        S = torch.zeros((V, D), dtype=dtype, device=device)
        S[r_idx, c_idx] = N

        self.embed = nn.Embedding.from_pretrained(W + S, freeze=freeze)
    @staticmethod
    def _make_base(D: int, scale: str = "box", seed: int = 0,
                   dtype=torch.float32, device=None) -> torch.Tensor:
        """
        Build x ∈ R^D such that |FFT(x)| is flat for k = 1..D-1 and the DC bin is 0.
        scale:
            - "unit": ||x||_2 = 1
            - "box":  max|x_i| = 1
        """
        # Build on CPU, then move to `device` at the end.
        # Use complex64 for half/bfloat16/float32; complex128 otherwise.
        if dtype in (torch.float16, torch.bfloat16, torch.float32):
            complex_dtype = torch.complex64
            work_float = torch.float32
        else:
            complex_dtype = torch.complex128
            work_float = torch.float64

        # seeded generator so the base vector is reproducible for a given `seed`
        g = torch.Generator()
        g.manual_seed(seed)

        X = torch.zeros(D, dtype=complex_dtype)
        # DC bin = 0 -> x has zero mean
        X[0] = 0.0
        if D % 2 == 0:
            # bins 1..D/2-1 are complex-conjugate pairs; the Nyquist bin must be real
            for k in range(1, D // 2):
                phi = torch.rand((), dtype=work_float, generator=g) * (2 * math.pi)
                val = (torch.cos(phi) + 1j * torch.sin(phi)).to(complex_dtype)
                X[k] = val
                X[D - k] = torch.conj(val)
            # Nyquist bin (real, ±1)
            X[D // 2] = 1.0 if torch.rand((), generator=g) < 0.5 else -1.0
        else:
            for k in range(1, (D - 1) // 2 + 1):
                phi = torch.rand((), dtype=work_float, generator=g) * (2 * math.pi)
                val = (torch.cos(phi) + 1j * torch.sin(phi)).to(complex_dtype)
                X[k] = val
                X[D - k] = torch.conj(val)

        x = torch.fft.ifft(X).real  # real length-D base vector
        x = x.to(work_float)
        if scale == "unit":
            x = x / (x.norm() + 1e-12)
        elif scale == "box":
            x = x / (x.abs().max() + 1e-12)
        else:
            raise ValueError("scale must be 'unit' or 'box'")
        x = x.to(dtype)
        if device is not None:
            x = x.to(device)
        return x
    def forward(self, input_ids: torch.LongTensor):
        # (batch, seq_len) -> (batch, seq_len, D)
        return self.embed(input_ids)
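
# A minimal usage sketch (illustrative only, not part of the training path; it
# assumes a config exposing vocab_size and n_embd, as elsewhere in this file):
#
#   emb = FlatRollEmbed(config)                        # frozen by default
#   ids = torch.randint(0, config.vocab_size, (2, 8))  # (batch, seq_len)
#   vecs = emb(ids)                                    # (2, 8, config.n_embd)
#
# Every row of emb.embed.weight is a circular shift of the same base vector,
# with its maximum boosted by 1/max(x) at the shifted anchor position.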

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": FlatRollEmbed(config) if config.frozen_embed else nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
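
As a quick sanity check of the new embedding (not part of this commit), the construction can be exercised as below; the SimpleNamespace config and the sizes are illustrative stand-ins, and the check relies on the default scale="box", seed=0 arguments:

import types
import torch

cfg = types.SimpleNamespace(vocab_size=16, n_embd=8)
emb = FlatRollEmbed(cfg)                       # scale="box", seed=0, frozen

W = emb.embed.weight                           # (16, 8), requires_grad == False
x = FlatRollEmbed._make_base(cfg.n_embd)       # same defaults as the constructor

mag = torch.fft.fft(x).abs()
assert mag[0] < 1e-4                           # DC bin ~ 0 (zero-mean x)
assert (mag[1:].max() - mag[1:].min()) < 1e-4  # flat spectrum for k >= 1

M = int(torch.argmax(x))
for r in range(cfg.vocab_size):
    # each row's maximum sits at the shifted anchor position (r + M) % D
    assert int(torch.argmax(W[r])) == (r + M) % cfg.n_embd

# With config.frozen_embed = True, GPT.__init__ selects this module as "wte";
# with it False, the learnable nn.Embedding is used instead.

Freezing the table (freeze=True) keeps the embedding geometry fixed while the rest of the model, including lm_head, remains trainable.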