Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 12:22:18 +00:00)
Update gpt.py
Add an optimal anchored fixed embedding
This commit is contained in:
parent dd6ff9a1cc
commit 1240e1299e

nanochat/gpt.py (101 changed lines)
@@ -151,12 +151,111 @@ class Block(nn.Module):
         return x


+class FlatRollEmbed(nn.Module):
+    """
+    Embedding matrix W in R^{V x D}
+    Optimized for epistemic purposes
+    """
+
+    def __init__(self, config, scale: str = "box", seed: int = 0,
+                 freeze: bool = True, dtype=None, device=None):
+        super().__init__()
+        V = int(config.vocab_size)
+        D = int(config.n_embd)
+        dtype = dtype or torch.float32
+        eps = 1e-12
+
+        # base vector x ∈ R^D with flat spectrum, DC=0
+        x = self._make_base(D, scale=scale, seed=seed, dtype=dtype, device=device)  # [D]
+
+        # circulant-like rows: row r is x rolled by (r % D)
+        # (simple per-row loop for clarity; can be vectorized if needed, see the sketch after the diff)
+        shifts = torch.arange(V, device=device)
+        rows = [torch.roll(x, shifts=int(s.item() % D), dims=0) for s in shifts]
+        W = torch.stack(rows, dim=0).to(dtype)
+
+        # align a single positive extremum via a "tower" S
+        M = int(torch.argmax(x))  # index of max in x
+        pm = x[M].item()
+        N = 1.0 / (pm + eps)  # safe reciprocal
+
+        # S[r, (r + M) % D] = N
+        r_idx = torch.arange(V, device=device)
+        c_idx = (r_idx + M) % D
+        S = torch.zeros((V, D), dtype=dtype, device=device)
+        S[r_idx, c_idx] = N
+
+        mixed = W + S
+        self.embed = nn.Embedding.from_pretrained(mixed, freeze=freeze)
+
+    @staticmethod
+    def _make_base(D: int, scale: str = "box", seed: int = 0,
+                   dtype=torch.float32, device=None) -> torch.Tensor:
+        """
+        Build x ∈ R^D where |FFT(x)| is flat for k=1..D-1 and DC=0.
+
+        scale:
+          - "unit": ||x||_2 = 1
+          - "box":  max|x_i| = 1
+        """
+        # Build on CPU, then move to device at the end.
+        # Use complex64 for float/bfloat/half; complex128 otherwise.
+        if dtype in (torch.float16, torch.bfloat16, torch.float32):
+            complex_dtype = torch.complex64
+            work_float = torch.float32
+        else:
+            complex_dtype = torch.complex128
+            work_float = torch.float64
+
+        # seed the random phases so the base vector is reproducible
+        gen = torch.Generator()
+        gen.manual_seed(seed)
+
+        X = torch.zeros(D, dtype=complex_dtype)
+
+        # DC bin = 0
+        X[0] = torch.tensor(0, dtype=complex_dtype)
+
+        if D % 2 == 0:
+            # bins 1..D/2-1 are complex-conjugate pairs; Nyquist bin must be real
+            for k in range(1, D // 2):
+                phi = torch.rand((), generator=gen, dtype=work_float) * (2 * math.pi)
+                val = torch.cos(phi).to(work_float) + 1j * torch.sin(phi).to(work_float)
+                val = val.to(complex_dtype)
+                X[k] = val
+                X[D - k] = torch.conj(val)
+            # Nyquist bin (real, ±1)
+            X[D // 2] = (1.0 if torch.rand((), generator=gen) < 0.5 else -1.0)
+        else:
+            for k in range(1, (D - 1) // 2 + 1):
+                phi = torch.rand((), generator=gen, dtype=work_float) * (2 * math.pi)
+                val = torch.cos(phi).to(work_float) + 1j * torch.sin(phi).to(work_float)
+                val = val.to(complex_dtype)
+                X[k] = val
+                X[D - k] = torch.conj(val)
+
+        x = torch.fft.ifft(X).real  # real length-D base vector (float32/64)
+        x = x.to(work_float)
+
+        if scale == "unit":
+            x = x / (x.norm() + 1e-12)
+        elif scale == "box":
+            x = x / (x.abs().max() + 1e-12)
+        else:
+            raise ValueError("scale must be 'unit' or 'box'")
+
+        x = x.to(dtype)
+        if device is not None:
+            x = x.to(device)
+        return x
+
+    def forward(self, input_ids: torch.LongTensor):
+        # (batch, seq_len, D)
+        return self.embed(input_ids)
+
+
 class GPT(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.transformer = nn.ModuleDict({
-            "wte": nn.Embedding(config.vocab_size, config.n_embd),
+            "wte": FlatRollEmbed(config) if config.frozen_embed else nn.Embedding(config.vocab_size, config.n_embd),
             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
         })
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
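
The new wte line reads config.frozen_embed, which does not appear elsewhere in this hunk, so the model config has to carry that field or the attribute lookup will fail for existing configs. A minimal sketch of what that could look like, assuming a dataclass-style GPTConfig as in nanochat; the field name comes from the diff, every other field and default here is only illustrative:

from dataclasses import dataclass

@dataclass
class GPTConfig:
    sequence_len: int = 1024   # illustrative defaults, not the repo's actual values
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6
    n_embd: int = 768
    # new flag assumed by this commit: pick the frozen FlatRollEmbed over a learned nn.Embedding
    frozen_embed: bool = False

With frozen_embed defaulting to False, existing configs keep building the usual trainable nn.Embedding and only opt into the fixed embedding explicitly.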
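
Two properties of the construction are easy to sanity-check: _make_base returns a vector whose FFT magnitude is zero at DC and constant on all other bins, and because the tower entry for row r sits at column (r + M) % D, every row of W + S is a circular shift of row 0. A quick check, assuming FlatRollEmbed is importable from nanochat.gpt after this commit; the small cfg stub below is hypothetical and only carries the two fields the module reads:

import torch
from types import SimpleNamespace
from nanochat.gpt import FlatRollEmbed

cfg = SimpleNamespace(vocab_size=64, n_embd=32)   # hypothetical stand-in for GPTConfig
emb = FlatRollEmbed(cfg, scale="box", seed=0)

# frozen: no gradients flow into the embedding table
assert not emb.embed.weight.requires_grad

# flat spectrum: DC bin near zero, all other bins share one magnitude
x = FlatRollEmbed._make_base(cfg.n_embd, scale="box", seed=0)
mag = torch.fft.fft(x.to(torch.float64)).abs()
assert mag[0] < 1e-3 * mag[1:].mean()
assert (mag[1:].max() - mag[1:].min()) < 1e-3 * mag[1:].mean()

# circulant structure: row r of (W + S) is row 0 rolled by r (for r < D)
Wm = emb.embed.weight
for r in range(1, cfg.n_embd):
    assert torch.allclose(Wm[r], torch.roll(Wm[0], shifts=r, dims=0))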
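
The loop that builds the rolled rows in __init__ runs once per vocabulary entry, so a ~50k-token vocab means ~50k torch.roll calls at init time. One way it could be vectorized, sketched here as an optional optimization (the helper name rolled_rows is not part of the commit), is to gather all shifts with a single advanced-indexing operation:

import torch

def rolled_rows(x: torch.Tensor, V: int) -> torch.Tensor:
    """Stack of V circular shifts of x: row r equals torch.roll(x, r % D)."""
    D = x.numel()
    shifts = torch.arange(V, device=x.device) % D                     # [V]
    cols = (torch.arange(D, device=x.device) - shifts[:, None]) % D   # [V, D] source indices
    return x[cols]                                                    # one gather instead of V rolls

Inside __init__ this would replace the list comprehension and torch.stack with W = rolled_rows(x, V).to(dtype).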