Instrument main gpt model with logging

Wollaston 2025-11-30 14:11:33 -05:00
parent 8a40915246
commit 62cfe4d4c3


@@ -12,16 +12,17 @@ Notable features:
 """
 import math
-from functools import partial
 from dataclasses import dataclass
+from functools import partial
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from logger import log
-from nanochat.common import get_dist_info, print0
-from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.common import get_dist_info, print0
+from nanochat.muon import DistMuon, Muon
 from nanochat.vsa import HRROperations
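
The decorator applied throughout the rest of this diff comes from the new `from logger import log` import. `logger.py` itself is not part of this commit, so the following is only a minimal sketch of what it might look like: a plain call-tracing decorator built on the standard `logging` module. Everything here except the name `log` is an assumption.

```python
# Hypothetical sketch of the imported decorator -- logger.py is not shown in
# this diff, so the name `log` is the only thing taken from the commit.
import functools
import logging

logging.basicConfig(level=logging.DEBUG)
_logger = logging.getLogger("nanochat.gpt")

def log(fn):
    """Log entry and exit of the decorated function or method."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        _logger.debug("enter %s", fn.__qualname__)
        result = fn(*args, **kwargs)
        _logger.debug("exit %s", fn.__qualname__)
        return result
    return wrapper
```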
@@ -40,12 +41,13 @@ class GPTConfig:
     d_vsa: int = 256

+@log
 def norm(x):
     # Purely functional rmsnorm with no learnable params
     return F.rms_norm(x, (x.size(-1),))

+@log
 def apply_rotary_emb(x, cos, sin):
     assert x.ndim == 4  # multihead attention
     d = x.shape[3] // 2
@@ -58,6 +60,7 @@ def apply_rotary_emb(x, cos, sin):
 class CausalSelfAttention(nn.Module):
+    @log
     def __init__(self, config, layer_idx):
         super().__init__()
         self.layer_idx = layer_idx
@@ -72,6 +75,7 @@ class CausalSelfAttention(nn.Module):
         self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

+    @log
     def forward(self, x, cos_sin, kv_cache):
         B, T, C = x.size()
@@ -112,15 +116,11 @@ class CausalSelfAttention(nn.Module):
         if kv_cache is None or Tq == Tk:
             # During training (no KV cache), attend as usual with causal attention
             # And even if there is KV cache, we can still use this simple version when Tq == Tk
-            y = F.scaled_dot_product_attention(
-                q, k, v, is_causal=True
-            )
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
         elif Tq == 1:
             # During inference but with a single query in this forward pass:
             # The query has to attend to all the keys/values in the cache
-            y = F.scaled_dot_product_attention(
-                q, k, v, is_causal=False
-            )
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
         else:
             # During inference AND we have a chunk of queries in this forward pass:
             # First, each query attends to all the cached keys/values (i.e. full prefix)
@@ -134,9 +134,7 @@ class CausalSelfAttention(nn.Module):
             attn_mask[:, prefix_len:] = torch.tril(
                 torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)
             )
-            y = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=attn_mask
-            )
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

         # Re-assemble the heads side by side and project back to residual stream
         y = y.transpose(1, 2).contiguous().view(B, T, -1)
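
The three branches above cover training (pure causal attention), single-token decode (the one query attends to everything in the cache), and chunked decode, where per the comments every new query sees the whole cached prefix plus a causal pattern over the new chunk. A toy illustration of that boolean mask, with made-up sizes:

```python
# Toy illustration of the chunked-decode mask; Tq and prefix_len are made up.
import torch

Tq, prefix_len = 3, 2
Tk = prefix_len + Tq
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool)
attn_mask[:, :prefix_len] = True  # every new query sees the full cached prefix
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool))
print(attn_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])
```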
@@ -145,18 +143,22 @@ class CausalSelfAttention(nn.Module):
 class MLP(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

+    @log
     def forward(self, x):
         x = self.c_fc(x)
         x = F.relu(x).square()
         x = self.c_proj(x)
         return x


 class Router(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.num_experts
@@ -173,19 +175,26 @@ class Router(nn.Module):
         # Project tokens to VSA space
         self.token_to_vsa = nn.Linear(config.n_embd, self.d_vsa, bias=False)
         # Fixed expert signatures in VSA space (not learnable)
-        self.register_buffer('expert_signatures', torch.randn(self.num_experts, self.d_vsa))
+        self.register_buffer(
+            "expert_signatures", torch.randn(self.num_experts, self.d_vsa)
+        )
         # Single VSA superposition cache containing all token-expert bindings
-        self.register_buffer('vsa_cache', torch.zeros(self.d_vsa))
+        self.register_buffer("vsa_cache", torch.zeros(self.d_vsa))
         # Combine normal and VSA routing
-        self.vsa_weight = nn.Parameter(torch.tensor(0.5)) # Learnable weight for VSA contribution
+        self.vsa_weight = nn.Parameter(
+            torch.tensor(0.5)
+        )  # Learnable weight for VSA contribution

+    @log
     def forward(self, x):
         # x: (B, T, D)
         B, T, D = x.size()
         # Normal routing to determine which experts are activated
         normal_logits = self.w_gate(x)  # (B, T, num_experts)
-        topk_logits, topk_indices = torch.topk(normal_logits, self.top_k, dim=-1) # (B, T, top_k)
+        topk_logits, topk_indices = torch.topk(
+            normal_logits, self.top_k, dim=-1
+        )  # (B, T, top_k)

         if self.use_vsa:
             # Project tokens to VSA space
@@ -196,7 +205,9 @@ class Router(nn.Module):
             if self.training:
                 # Vectorized cache update
                 # Flatten indices and tokens for batch processing
-                flat_tokens = token_vsa.unsqueeze(2).expand(-1, -1, self.top_k, -1) # (B, T, top_k, d_vsa)
+                flat_tokens = token_vsa.unsqueeze(2).expand(
+                    -1, -1, self.top_k, -1
+                )  # (B, T, top_k, d_vsa)
                 flat_tokens = flat_tokens.reshape(-1, self.d_vsa)  # (B*T*top_k, d_vsa)
                 flat_indices = topk_indices.reshape(-1)  # (B*T*top_k,)
@@ -211,40 +222,66 @@ class Router(nn.Module):
                 self.vsa_cache = self.vsa_cache + cache_update.detach()

             # Compute VSA routing scores by querying the single superposition
-            vsa_scores = torch.zeros(B, T, self.num_experts, device=x.device, dtype=x.dtype)
+            vsa_scores = torch.zeros(
+                B, T, self.num_experts, device=x.device, dtype=x.dtype
+            )
             if torch.norm(self.vsa_cache) > 1e-8:  # Only if cache has content
                 # Vectorized similarity computation
                 # Expand tokens to (B, T, num_experts, d_vsa) and expert_signatures to (B, T, num_experts, d_vsa)
-                token_expanded = token_vsa.unsqueeze(2).expand(B, T, self.num_experts, self.d_vsa) # (B, T, num_experts, d_vsa)
-                expert_expanded = self.expert_signatures.unsqueeze(0).unsqueeze(0).expand(B, T, self.num_experts, self.d_vsa) # (B, T, num_experts, d_vsa)
+                token_expanded = token_vsa.unsqueeze(2).expand(
+                    B, T, self.num_experts, self.d_vsa
+                )  # (B, T, num_experts, d_vsa)
+                expert_expanded = (
+                    self.expert_signatures.unsqueeze(0)
+                    .unsqueeze(0)
+                    .expand(B, T, self.num_experts, self.d_vsa)
+                )  # (B, T, num_experts, d_vsa)

                 # Vectorized binding: bind all tokens with all expert signatures
-                queries = self.vsa_ops.bind(token_expanded.reshape(-1, self.d_vsa), expert_expanded.reshape(-1, self.d_vsa)) # (B*T*num_experts, d_vsa)
-                queries = queries.reshape(B, T, self.num_experts, self.d_vsa) # (B, T, num_experts, d_vsa)
+                queries = self.vsa_ops.bind(
+                    token_expanded.reshape(-1, self.d_vsa),
+                    expert_expanded.reshape(-1, self.d_vsa),
+                )  # (B*T*num_experts, d_vsa)
+                queries = queries.reshape(
+                    B, T, self.num_experts, self.d_vsa
+                )  # (B, T, num_experts, d_vsa)

                 # Compute similarities with the cache for all queries at once
                 # Flatten queries and compute all similarities in one batch
-                queries_flat = queries.reshape(-1, self.d_vsa) # (B*T*num_experts, d_vsa)
-                cache_expanded = self.vsa_cache.unsqueeze(0).expand(queries_flat.size(0), -1) # (B*T*num_experts, d_vsa)
+                queries_flat = queries.reshape(
+                    -1, self.d_vsa
+                )  # (B*T*num_experts, d_vsa)
+                cache_expanded = self.vsa_cache.unsqueeze(0).expand(
+                    queries_flat.size(0), -1
+                )  # (B*T*num_experts, d_vsa)

                 # Use VSA similarity operation for all queries
-                similarities_flat = self.vsa_ops.similarity(queries_flat, cache_expanded) # (B*T*num_experts,)
+                similarities_flat = self.vsa_ops.similarity(
+                    queries_flat, cache_expanded
+                )  # (B*T*num_experts,)
                 vsa_scores = similarities_flat.reshape(B, T, self.num_experts)

             # Combine normal and VSA routing with learnable weight
-            combined_logits = normal_logits + torch.sigmoid(self.vsa_weight) * vsa_scores
-            topk_logits, topk_indices = torch.topk(combined_logits, self.top_k, dim=-1) # (B, T, top_k)
+            combined_logits = (
+                normal_logits + torch.sigmoid(self.vsa_weight) * vsa_scores
+            )
+            topk_logits, topk_indices = torch.topk(
+                combined_logits, self.top_k, dim=-1
+            )  # (B, T, top_k)

         topk_gates = F.softmax(topk_logits, dim=-1)  # (B, T, top_k)
         return topk_indices, topk_gates  # both are (B, T, top_k)


 class MoE(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.num_experts
         self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
         self.router = Router(config)

+    @log
     def forward(self, x):
         # x: (B, T, D)
         B, T, D = x.size()
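
The router above queries a single superposition cache by binding each token's VSA projection with each expert signature and scoring the result against the cache. `nanochat.vsa.HRROperations` is not part of this diff, so the sketch below only assumes a conventional holographic-reduced-representation setup (circular convolution for `bind`, cosine similarity for `similarity`); the real class may differ.

```python
# Hypothetical stand-in for nanochat.vsa.HRROperations (not shown in this
# diff); assumes FFT-based circular-convolution binding and cosine similarity.
import torch
import torch.nn.functional as F

class HRROperationsSketch:
    def bind(self, a, b):
        # Circular convolution along the last dim: (N, d), (N, d) -> (N, d)
        return torch.fft.irfft(torch.fft.rfft(a) * torch.fft.rfft(b), n=a.size(-1))

    def similarity(self, a, b):
        # Cosine similarity per row: (N, d), (N, d) -> (N,)
        return F.cosine_similarity(a, b, dim=-1)
```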
@@ -259,14 +296,18 @@ class MoE(nn.Module):
         for i in range(self.num_experts):
             # Find which tokens use expert i and their corresponding gate values
-            expert_mask = (topk_indices_flat == i)  # (B*T, top_k)
+            expert_mask = topk_indices_flat == i  # (B*T, top_k)
             if expert_mask.any():
                 # Get tokens that route to this expert
                 token_indices, k_indices = torch.where(expert_mask)
                 if len(token_indices) > 0:
                     selected_tokens = x_flat[token_indices]  # (num_selected, D)
-                    expert_output = self.experts[i](selected_tokens)  # (num_selected, D)
-                    gate_weights = topk_gates_flat[token_indices, k_indices].unsqueeze(-1)  # (num_selected, 1)
+                    expert_output = self.experts[i](
+                        selected_tokens
+                    )  # (num_selected, D)
+                    gate_weights = topk_gates_flat[token_indices, k_indices].unsqueeze(
+                        -1
+                    )  # (num_selected, 1)
                     # Add weighted expert output back to the corresponding tokens
                     output_flat[token_indices] += expert_output * gate_weights
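
The dispatch loop above scatters tokens to their top-k experts and accumulates gate-weighted outputs. For reference, the same pattern as a self-contained toy, where every shape and tensor is made up rather than taken from the model:

```python
# Toy version of the MoE dispatch loop above; all sizes here are invented.
import torch

num_tokens, D, top_k, num_experts = 6, 4, 2, 3
x_flat = torch.randn(num_tokens, D)
topk_indices_flat = torch.randint(0, num_experts, (num_tokens, top_k))
topk_gates_flat = torch.softmax(torch.randn(num_tokens, top_k), dim=-1)
experts = [torch.nn.Linear(D, D, bias=False) for _ in range(num_experts)]

output_flat = torch.zeros_like(x_flat)
for i in range(num_experts):
    token_indices, k_indices = torch.where(topk_indices_flat == i)
    if len(token_indices) > 0:
        expert_output = experts[i](x_flat[token_indices])            # (num_selected, D)
        gate_weights = topk_gates_flat[token_indices, k_indices].unsqueeze(-1)
        output_flat[token_indices] += expert_output * gate_weights   # weighted accumulate
```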
@@ -275,6 +316,7 @@ class MoE(nn.Module):
 class Block(nn.Module):
+    @log
     def __init__(self, config, layer_idx):
         super().__init__()
         self.attn = CausalSelfAttention(config, layer_idx)
@@ -283,6 +325,7 @@ class Block(nn.Module):
         else:
             self.mlp = MLP(config)

+    @log
     def forward(self, x, cos_sin, kv_cache):
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
@@ -290,6 +333,7 @@ class Block(nn.Module):
 class GPT(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -316,15 +360,16 @@ class GPT(nn.Module):
         )  # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)

+    @log
     def init_weights(self):
         self.apply(self._init_weights)
         # zero out classifier weights
         torch.nn.init.zeros_(self.lm_head.weight)
         # zero out c_proj weights in all blocks
         for block in self.transformer.h:
-            if hasattr(block.mlp, 'c_proj'):  # Regular MLP
+            if hasattr(block.mlp, "c_proj"):  # Regular MLP
                 torch.nn.init.zeros_(block.mlp.c_proj.weight)
-            elif hasattr(block.mlp, 'experts'):  # MoE
+            elif hasattr(block.mlp, "experts"):  # MoE
                 for expert in block.mlp.experts:
                     torch.nn.init.zeros_(expert.c_proj.weight)
             torch.nn.init.zeros_(block.attn.c_proj.weight)
@@ -336,6 +381,7 @@ class GPT(nn.Module):
         if self.transformer.wte.weight.device.type == "cuda":
             self.transformer.wte.to(dtype=torch.bfloat16)

+    @log
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             # https://arxiv.org/pdf/2310.17813
@@ -349,6 +395,7 @@ class GPT(nn.Module):
             torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

     # TODO: bump base theta more, e.g. 100K is more common more recently
+    @log
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # autodetect the device from model embeddings
         if device is None:
@@ -368,9 +415,11 @@ class GPT(nn.Module):
         )  # add batch and head dims for later broadcasting
         return cos, sin

+    @log
     def get_device(self):
         return self.transformer.wte.weight.device

+    @log
     def estimate_flops(self):
         """Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311"""
         nparams = sum(p.numel() for p in self.parameters())
@@ -384,6 +433,7 @@ class GPT(nn.Module):
         num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
         return num_flops_per_token

+    @log
     def setup_optimizers(
         self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0
     ):
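
The `estimate_flops` lines in the hunk above use the standard approximation of 6 FLOPs per non-embedding parameter plus a 12·l·h·q·t attention term (per the PaLM reference cited in its docstring). A quick sanity check with hypothetical numbers, not taken from this model's config:

```python
# Hypothetical parameter counts and dimensions, for illustration only.
nparams, nparams_embedding = 124_000_000, 38_000_000  # total vs. embedding params
l, h, q, t = 12, 12, 64, 1024                         # layers, heads, head_dim, seq_len
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
print(f"{num_flops_per_token:.3e}")  # ~6.3e+08 FLOPs per token
```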
@@ -421,6 +471,7 @@ class GPT(nn.Module):
             group["initial_lr"] = group["lr"]
         return optimizers

+    @log
     def forward(self, idx, targets=None, kv_cache=None, loss_reduction="mean"):
         B, T = idx.size()
@@ -467,6 +518,7 @@ class GPT(nn.Module):
             logits = softcap * torch.tanh(logits / softcap)  # logits softcap
         return logits

+    @log
     @torch.inference_mode()
     def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
         """