Instrument main GPT model with logging

Wollaston 2025-11-30 14:11:33 -05:00
parent 8a40915246
commit 62cfe4d4c3


@@ -12,16 +12,17 @@ Notable features:
"""
import math
from dataclasses import dataclass
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from logger import log
from nanochat.adamw import DistAdamW
from nanochat.common import get_dist_info, print0
from nanochat.muon import DistMuon, Muon
from nanochat.vsa import HRROperations
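The newly imported @log decorator does the instrumentation throughout this commit; its implementation in logger.py is not part of this diff, so the following is only a hypothetical sketch of what such a decorator might look like (entry/exit logging with the wrapped function's qualified name):

import functools
import logging

def log(fn):
    # Hypothetical stand-in for logger.log: records entry and exit of each call.
    logger = logging.getLogger("nanochat")
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        logger.debug("enter %s", fn.__qualname__)
        result = fn(*args, **kwargs)
        logger.debug("exit %s", fn.__qualname__)
        return result
    return wrapper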
@@ -40,12 +41,13 @@ class GPTConfig:
d_vsa: int = 256
@log
def norm(x):
# Purely functional rmsnorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
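As a sanity check on the comment above: F.rms_norm with no weight is just division by the root-mean-square over the last dimension. A quick standalone sketch:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
assert torch.allclose(F.rms_norm(x, (x.size(-1),), eps=1e-6), manual, atol=1e-5)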
@log
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
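The diff truncates the rest of this function. For reference, a common rotate-halves formulation of rotary embeddings looks like the following (a sketch, not necessarily the exact elided body):

x1, x2 = x[..., :d], x[..., d:]    # split head_dim into two halves
y1 = x1 * cos + x2 * sin           # rotate each (x1, x2) pair by the
y2 = -x1 * sin + x2 * cos          # position-dependent angle
out = torch.cat([y1, y2], dim=-1)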
@@ -58,6 +60,7 @@ def apply_rotary_emb(x, cos, sin):
class CausalSelfAttention(nn.Module):
@log
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
@@ -72,6 +75,7 @@ class CausalSelfAttention(nn.Module):
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
@log
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
@@ -108,19 +112,15 @@ class CausalSelfAttention(nn.Module):
repeat_factor = self.n_head // self.n_kv_head
k = k.repeat_interleave(repeat_factor, dim=1)
v = v.repeat_interleave(repeat_factor, dim=1)
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
elif Tq == 1:
# During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
@@ -134,9 +134,7 @@ class CausalSelfAttention(nn.Module):
attn_mask[:, prefix_len:] = torch.tril(
torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)
)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
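The third branch above builds a hybrid mask: each of the Tq new queries sees the entire cached prefix, plus a causal window over the new chunk itself. A standalone sketch of that mask construction:

import torch

Tq, Tk = 4, 10                              # new queries, total keys in cache
prefix_len = Tk - Tq
attn_mask = torch.zeros(Tq, Tk, dtype=torch.bool)
attn_mask[:, :prefix_len] = True            # full visibility of the prefix
attn_mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, dtype=torch.bool))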
@@ -145,136 +143,180 @@ class CausalSelfAttention(nn.Module):
class MLP(nn.Module):
@log
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
@log
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
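The relu(x).square() here is the "ReLU squared" nonlinearity: zero for negative inputs, x**2 for positive ones. For example:

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, 0.5, 3.0])
print(F.relu(x).square())  # tensor([0.0000, 0.2500, 9.0000])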
class Router(nn.Module):
@log
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.top_k = config.top_k
self.use_vsa = config.use_vsa
self.d_vsa = config.d_vsa
# Normal routing gate
self.w_gate = nn.Linear(config.n_embd, self.num_experts, bias=False)
# VSA routing components
if self.use_vsa:
self.vsa_ops = HRROperations()
# Project tokens to VSA space
self.token_to_vsa = nn.Linear(config.n_embd, self.d_vsa, bias=False)
# Fixed expert signatures in VSA space (not learnable)
self.register_buffer(
"expert_signatures", torch.randn(self.num_experts, self.d_vsa)
)
# Single VSA superposition cache containing all token-expert bindings
self.register_buffer("vsa_cache", torch.zeros(self.d_vsa))
# Combine normal and VSA routing
self.vsa_weight = nn.Parameter(
torch.tensor(0.5)
) # Learnable weight for VSA contribution
@log
def forward(self, x):
# x: (B, T, D)
B, T, D = x.size()
# Normal routing to determine which experts are activated
normal_logits = self.w_gate(x) # (B, T, num_experts)
topk_logits, topk_indices = torch.topk(
normal_logits, self.top_k, dim=-1
) # (B, T, top_k)
if self.use_vsa:
# Project tokens to VSA space
token_vsa = self.token_to_vsa(x) # (B, T, d_vsa)
# Update VSA cache: bind activated tokens to their expert signatures and add to single superposition
# Only update cache during training mode
if self.training:
# Vectorized cache update
# Flatten indices and tokens for batch processing
flat_tokens = token_vsa.unsqueeze(2).expand(
-1, -1, self.top_k, -1
) # (B, T, top_k, d_vsa)
flat_tokens = flat_tokens.reshape(-1, self.d_vsa) # (B*T*top_k, d_vsa)
flat_indices = topk_indices.reshape(-1) # (B*T*top_k,)
# Get corresponding expert signatures
expert_sigs = self.expert_signatures[flat_indices] # (B*T*top_k, d_vsa)
# Vectorized binding using VSA operations
bound_tokens = self.vsa_ops.bind(flat_tokens, expert_sigs)
# Sum all bound tokens and add to cache
cache_update = torch.sum(bound_tokens, dim=0)
self.vsa_cache = self.vsa_cache + cache_update.detach()
# Compute VSA routing scores by querying the single superposition
vsa_scores = torch.zeros(
B, T, self.num_experts, device=x.device, dtype=x.dtype
)
if torch.norm(self.vsa_cache) > 1e-8: # Only if cache has content
# Vectorized similarity computation
# Expand tokens to (B, T, num_experts, d_vsa) and expert_signatures to (B, T, num_experts, d_vsa)
token_expanded = token_vsa.unsqueeze(2).expand(
B, T, self.num_experts, self.d_vsa
) # (B, T, num_experts, d_vsa)
expert_expanded = (
self.expert_signatures.unsqueeze(0)
.unsqueeze(0)
.expand(B, T, self.num_experts, self.d_vsa)
) # (B, T, num_experts, d_vsa)
# Vectorized binding: bind all tokens with all expert signatures
queries = self.vsa_ops.bind(
token_expanded.reshape(-1, self.d_vsa),
expert_expanded.reshape(-1, self.d_vsa),
) # (B*T*num_experts, d_vsa)
queries = queries.reshape(
B, T, self.num_experts, self.d_vsa
) # (B, T, num_experts, d_vsa)
# Compute similarities with the cache for all queries at once
# Flatten queries and compute all similarities in one batch
queries_flat = queries.reshape(
-1, self.d_vsa
) # (B*T*num_experts, d_vsa)
cache_expanded = self.vsa_cache.unsqueeze(0).expand(
queries_flat.size(0), -1
) # (B*T*num_experts, d_vsa)
# Use VSA similarity operation for all queries
similarities_flat = self.vsa_ops.similarity(
queries_flat, cache_expanded
) # (B*T*num_experts,)
vsa_scores = similarities_flat.reshape(B, T, self.num_experts)
# Combine normal and VSA routing with learnable weight
combined_logits = (
normal_logits + torch.sigmoid(self.vsa_weight) * vsa_scores
)
topk_logits, topk_indices = torch.topk(
combined_logits, self.top_k, dim=-1
) # (B, T, top_k)
topk_gates = F.softmax(topk_logits, dim=-1) # (B, T, top_k)
return topk_indices, topk_gates # both are (B, T, top_k)
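The bind/similarity calls above come from nanochat.vsa.HRROperations, which this diff does not show. As an assumption, classic Holographic Reduced Representations bind vectors by circular convolution (a pointwise product in Fourier space) and score retrieval by cosine similarity; a minimal sketch of what those two operations could look like:

import torch
import torch.nn.functional as F

class HRROpsSketch:  # hypothetical stand-in for HRROperations
    def bind(self, a, b):
        # circular convolution along the last dim via FFT
        return torch.fft.irfft(torch.fft.rfft(a) * torch.fft.rfft(b), n=a.size(-1))

    def similarity(self, a, b):
        # cosine similarity over the last dim
        return F.cosine_similarity(a, b, dim=-1)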
class MoE(nn.Module):
@log
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
self.router = Router(config)
@log
def forward(self, x):
# x: (B, T, D)
B, T, D = x.size()
topk_indices, topk_gates = self.router(x) # both are (B, T, top_k)
output = torch.zeros_like(x)
# Flatten for easier token-level processing
x_flat = x.view(-1, D) # (B*T, D)
output_flat = output.view(-1, D) # (B*T, D)
topk_indices_flat = topk_indices.view(-1, topk_indices.size(-1)) # (B*T, top_k)
topk_gates_flat = topk_gates.view(-1, topk_gates.size(-1)) # (B*T, top_k)
for i in range(self.num_experts):
# Find which tokens use expert i and their corresponding gate values
expert_mask = topk_indices_flat == i # (B*T, top_k)
if expert_mask.any():
# Get tokens that route to this expert
token_indices, k_indices = torch.where(expert_mask)
if len(token_indices) > 0:
selected_tokens = x_flat[token_indices] # (num_selected, D)
expert_output = self.experts[i](
selected_tokens
) # (num_selected, D)
gate_weights = topk_gates_flat[token_indices, k_indices].unsqueeze(
-1
) # (num_selected, 1)
# Add weighted expert output back to the corresponding tokens
output_flat[token_indices] += expert_output * gate_weights
return output_flat.view(B, T, D)
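Putting the router and experts together, a usage sketch (the GPTConfig fields follow the hunks above; the concrete values are illustrative only):

cfg = GPTConfig(num_experts=8, top_k=2, use_vsa=True, d_vsa=256)
moe = MoE(cfg)
y = moe(torch.randn(2, 16, cfg.n_embd))  # -> (2, 16, cfg.n_embd)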
class Block(nn.Module):
@log
def __init__(self, config, layer_idx):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
@@ -283,6 +325,7 @@ class Block(nn.Module):
else:
self.mlp = MLP(config)
@log
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
x = x + self.mlp(norm(x))
@@ -290,6 +333,7 @@ class Block(nn.Module):
class GPT(nn.Module):
@log
def __init__(self, config):
super().__init__()
self.config = config
@@ -316,15 +360,16 @@ class GPT(nn.Module):
) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
@log
def init_weights(self):
self.apply(self._init_weights)
# zero out classifier weights
torch.nn.init.zeros_(self.lm_head.weight)
# zero out c_proj weights in all blocks
for block in self.transformer.h:
if hasattr(block.mlp, "c_proj"): # Regular MLP
torch.nn.init.zeros_(block.mlp.c_proj.weight)
elif hasattr(block.mlp, "experts"): # MoE
for expert in block.mlp.experts:
torch.nn.init.zeros_(expert.c_proj.weight)
torch.nn.init.zeros_(block.attn.c_proj.weight)
@@ -336,6 +381,7 @@ class GPT(nn.Module):
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)
@log
def _init_weights(self, module):
if isinstance(module, nn.Linear):
# https://arxiv.org/pdf/2310.17813
@@ -349,6 +395,7 @@ class GPT(nn.Module):
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
# TODO: bump base theta, e.g. 100K is more common these days
@log
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# autodetect the device from model embeddings
if device is None:
@@ -368,9 +415,11 @@ class GPT(nn.Module):
) # add batch and head dims for later broadcasting
return cos, sin
@log
def get_device(self):
return self.transformer.wte.weight.device
@log
def estimate_flops(self):
"""Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311"""
nparams = sum(p.numel() for p in self.parameters())
@@ -384,6 +433,7 @@ class GPT(nn.Module):
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
return num_flops_per_token
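A worked instance of the 6*N + 12*l*h*q*t estimate above (the dimensions are made up for illustration):

l, h, q, t = 12, 12, 64, 2048                 # layers, heads, head_dim, seq_len
nparams, nparams_embedding = 124e6, 38e6
flops = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
print(f"{flops:.2e}")                          # ~7.4e8 FLOPs per token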
@log
def setup_optimizers(
self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0
):
@@ -421,6 +471,7 @@ class GPT(nn.Module):
group["initial_lr"] = group["lr"]
return optimizers
@log
def forward(self, idx, targets=None, kv_cache=None, loss_reduction="mean"):
B, T = idx.size()
@@ -467,6 +518,7 @@ class GPT(nn.Module):
logits = softcap * torch.tanh(logits / softcap) # logits softcap
return logits
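The softcap squashes logits smoothly into (-softcap, softcap) rather than hard-clipping them; for example:

import torch

softcap = 15.0
logits = torch.tensor([1.0, 20.0, 100.0])
print(softcap * torch.tanh(logits / softcap))  # ~[0.999, 13.05, 15.00]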
@log
@torch.inference_mode()
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
"""