From 62cfe4d4c3e41ffb35ae4c4f262c42e530671ad5 Mon Sep 17 00:00:00 2001
From: Wollaston
Date: Sun, 30 Nov 2025 14:11:33 -0500
Subject: [PATCH] Instrument main gpt model with logging

---
 nanochat/gpt.py | 160 ++++++++++++++++++++++++++++++++----------------
 1 file changed, 106 insertions(+), 54 deletions(-)

diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index cde6745..506a5fb 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -12,16 +12,17 @@ Notable features:
 """
 
 import math
-from functools import partial
 from dataclasses import dataclass
+from functools import partial
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from logger import log
 
-from nanochat.common import get_dist_info, print0
-from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.common import get_dist_info, print0
+from nanochat.muon import DistMuon, Muon
 from nanochat.vsa import HRROperations
 
 
@@ -40,12 +41,13 @@ class GPTConfig:
     d_vsa: int = 256
 
 
-
+@log
 def norm(x):
     # Purely functional rmsnorm with no learnable params
     return F.rms_norm(x, (x.size(-1),))
 
 
+@log
 def apply_rotary_emb(x, cos, sin):
     assert x.ndim == 4  # multihead attention
     d = x.shape[3] // 2
@@ -58,6 +60,7 @@ def apply_rotary_emb(x, cos, sin):
 
 
 class CausalSelfAttention(nn.Module):
+    @log
     def __init__(self, config, layer_idx):
         super().__init__()
         self.layer_idx = layer_idx
@@ -72,6 +75,7 @@ class CausalSelfAttention(nn.Module):
         self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
 
+    @log
     def forward(self, x, cos_sin, kv_cache):
         B, T, C = x.size()
 
@@ -108,19 +112,15 @@ class CausalSelfAttention(nn.Module):
             repeat_factor = self.n_head // self.n_kv_head
             k = k.repeat_interleave(repeat_factor, dim=1)
             v = v.repeat_interleave(repeat_factor, dim=1)
-
+
         if kv_cache is None or Tq == Tk:
             # During training (no KV cache), attend as usual with causal attention
             # And even if there is KV cache, we can still use this simple version when Tq == Tk
-            y = F.scaled_dot_product_attention(
-                q, k, v, is_causal=True
-            )
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
         elif Tq == 1:
             # During inference but with a single query in this forward pass:
             # The query has to attend to all the keys/values in the cache
-            y = F.scaled_dot_product_attention(
-                q, k, v, is_causal=False
-            )
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
         else:
             # During inference AND we have a chunk of queries in this forward pass:
             # First, each query attends to all the cached keys/values (i.e. full prefix)
@@ -134,9 +134,7 @@ class CausalSelfAttention(nn.Module):
             attn_mask[:, prefix_len:] = torch.tril(
                 torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)
             )
-            y = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=attn_mask
-            )
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
 
         # Re-assemble the heads side by side and project back to residual stream
         y = y.transpose(1, 2).contiguous().view(B, T, -1)
@@ -145,136 +143,180 @@ class CausalSelfAttention(nn.Module):
 
 
 class MLP(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
 
+    @log
     def forward(self, x):
         x = self.c_fc(x)
         x = F.relu(x).square()
         x = self.c_proj(x)
         return x
-
+
+
 class Router(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.num_experts
         self.top_k = config.top_k
         self.use_vsa = config.use_vsa
        self.d_vsa = config.d_vsa
-
+
         # Normal routing gate
         self.w_gate = nn.Linear(config.n_embd, self.num_experts, bias=False)
-
+
         # VSA routing components
         if self.use_vsa:
             self.vsa_ops = HRROperations()
             # Project tokens to VSA space
             self.token_to_vsa = nn.Linear(config.n_embd, self.d_vsa, bias=False)
             # Fixed expert signatures in VSA space (not learnable)
-            self.register_buffer('expert_signatures', torch.randn(self.num_experts, self.d_vsa))
+            self.register_buffer(
+                "expert_signatures", torch.randn(self.num_experts, self.d_vsa)
+            )
             # Single VSA superposition cache containing all token-expert bindings
-            self.register_buffer('vsa_cache', torch.zeros(self.d_vsa))
+            self.register_buffer("vsa_cache", torch.zeros(self.d_vsa))
             # Combine normal and VSA routing
-            self.vsa_weight = nn.Parameter(torch.tensor(0.5))  # Learnable weight for VSA contribution
+            self.vsa_weight = nn.Parameter(
+                torch.tensor(0.5)
+            )  # Learnable weight for VSA contribution
 
+    @log
     def forward(self, x):
         # x: (B, T, D)
         B, T, D = x.size()
-
+
         # Normal routing to determine which experts are activated
         normal_logits = self.w_gate(x)  # (B, T, num_experts)
-        topk_logits, topk_indices = torch.topk(normal_logits, self.top_k, dim=-1)  # (B, T, top_k)
-
+        topk_logits, topk_indices = torch.topk(
+            normal_logits, self.top_k, dim=-1
+        )  # (B, T, top_k)
+
         if self.use_vsa:
             # Project tokens to VSA space
             token_vsa = self.token_to_vsa(x)  # (B, T, d_vsa)
-
+
             # Update VSA cache: bind activated tokens to their expert signatures and add to single superposition
             # Only update cache during training mode
             if self.training:
                 # Vectorized cache update
                 # Flatten indices and tokens for batch processing
-                flat_tokens = token_vsa.unsqueeze(2).expand(-1, -1, self.top_k, -1)  # (B, T, top_k, d_vsa)
+                flat_tokens = token_vsa.unsqueeze(2).expand(
+                    -1, -1, self.top_k, -1
+                )  # (B, T, top_k, d_vsa)
                 flat_tokens = flat_tokens.reshape(-1, self.d_vsa)  # (B*T*top_k, d_vsa)
                 flat_indices = topk_indices.reshape(-1)  # (B*T*top_k,)
-
+
                 # Get corresponding expert signatures
                 expert_sigs = self.expert_signatures[flat_indices]  # (B*T*top_k, d_vsa)
-
+
                 # Vectorized binding using VSA operations
                 bound_tokens = self.vsa_ops.bind(flat_tokens, expert_sigs)
-
+
                 # Sum all bound tokens and add to cache
                 cache_update = torch.sum(bound_tokens, dim=0)
                 self.vsa_cache = self.vsa_cache + cache_update.detach()
-
+
             # Compute VSA routing scores by querying the single superposition
-            vsa_scores = torch.zeros(B, T, self.num_experts, device=x.device, dtype=x.dtype)
+            vsa_scores = torch.zeros(
+                B, T, self.num_experts, device=x.device, dtype=x.dtype
+            )
             if torch.norm(self.vsa_cache) > 1e-8:  # Only if cache has content
                 # Vectorized similarity computation
                 # Expand tokens to (B, T, num_experts, d_vsa) and expert_signatures to (B, T, num_experts, d_vsa)
-                token_expanded = token_vsa.unsqueeze(2).expand(B, T, self.num_experts, self.d_vsa)  # (B, T, num_experts, d_vsa)
-                expert_expanded = self.expert_signatures.unsqueeze(0).unsqueeze(0).expand(B, T, self.num_experts, self.d_vsa)  # (B, T, num_experts, d_vsa)
-
+                token_expanded = token_vsa.unsqueeze(2).expand(
+                    B, T, self.num_experts, self.d_vsa
+                )  # (B, T, num_experts, d_vsa)
+                expert_expanded = (
+                    self.expert_signatures.unsqueeze(0)
+                    .unsqueeze(0)
+                    .expand(B, T, self.num_experts, self.d_vsa)
+                )  # (B, T, num_experts, d_vsa)
+
                 # Vectorized binding: bind all tokens with all expert signatures
-                queries = self.vsa_ops.bind(token_expanded.reshape(-1, self.d_vsa), expert_expanded.reshape(-1, self.d_vsa))  # (B*T*num_experts, d_vsa)
-                queries = queries.reshape(B, T, self.num_experts, self.d_vsa)  # (B, T, num_experts, d_vsa)
-
+                queries = self.vsa_ops.bind(
+                    token_expanded.reshape(-1, self.d_vsa),
+                    expert_expanded.reshape(-1, self.d_vsa),
+                )  # (B*T*num_experts, d_vsa)
+                queries = queries.reshape(
+                    B, T, self.num_experts, self.d_vsa
+                )  # (B, T, num_experts, d_vsa)
+
                 # Compute similarities with the cache for all queries at once
                 # Flatten queries and compute all similarities in one batch
-                queries_flat = queries.reshape(-1, self.d_vsa)  # (B*T*num_experts, d_vsa)
-                cache_expanded = self.vsa_cache.unsqueeze(0).expand(queries_flat.size(0), -1)  # (B*T*num_experts, d_vsa)
-
+                queries_flat = queries.reshape(
+                    -1, self.d_vsa
+                )  # (B*T*num_experts, d_vsa)
+                cache_expanded = self.vsa_cache.unsqueeze(0).expand(
+                    queries_flat.size(0), -1
+                )  # (B*T*num_experts, d_vsa)
+
                 # Use VSA similarity operation for all queries
-                similarities_flat = self.vsa_ops.similarity(queries_flat, cache_expanded)  # (B*T*num_experts,)
+                similarities_flat = self.vsa_ops.similarity(
+                    queries_flat, cache_expanded
+                )  # (B*T*num_experts,)
                 vsa_scores = similarities_flat.reshape(B, T, self.num_experts)
-
+
             # Combine normal and VSA routing with learnable weight
-            combined_logits = normal_logits + torch.sigmoid(self.vsa_weight) * vsa_scores
-            topk_logits, topk_indices = torch.topk(combined_logits, self.top_k, dim=-1)  # (B, T, top_k)
-
+            combined_logits = (
+                normal_logits + torch.sigmoid(self.vsa_weight) * vsa_scores
+            )
+            topk_logits, topk_indices = torch.topk(
+                combined_logits, self.top_k, dim=-1
+            )  # (B, T, top_k)
+
         topk_gates = F.softmax(topk_logits, dim=-1)  # (B, T, top_k)
         return topk_indices, topk_gates  # both are (B, T, top_k)
-
+
+
 class MoE(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.num_experts
         self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
         self.router = Router(config)
 
+    @log
     def forward(self, x):
         # x: (B, T, D)
         B, T, D = x.size()
         topk_indices, topk_gates = self.router(x)  # both are (B, T, top_k)
         output = torch.zeros_like(x)
-
+
         # Flatten for easier token-level processing
         x_flat = x.view(-1, D)  # (B*T, D)
         output_flat = output.view(-1, D)  # (B*T, D)
         topk_indices_flat = topk_indices.view(-1, topk_indices.size(-1))  # (B*T, top_k)
         topk_gates_flat = topk_gates.view(-1, topk_gates.size(-1))  # (B*T, top_k)
-
+
         for i in range(self.num_experts):
             # Find which tokens use expert i and their corresponding gate values
-            expert_mask = (topk_indices_flat == i)  # (B*T, top_k)
+            expert_mask = topk_indices_flat == i  # (B*T, top_k)
             if expert_mask.any():
                 # Get tokens that route to this expert
                 token_indices, k_indices = torch.where(expert_mask)
                 if len(token_indices) > 0:
                     selected_tokens = x_flat[token_indices]  # (num_selected, D)
-                    expert_output = self.experts[i](selected_tokens)  # (num_selected, D)
-                    gate_weights = topk_gates_flat[token_indices, k_indices].unsqueeze(-1)  # (num_selected, 1)
-
+                    expert_output = self.experts[i](
+                        selected_tokens
+                    )  # (num_selected, D)
+                    gate_weights = topk_gates_flat[token_indices, k_indices].unsqueeze(
+                        -1
+                    )  # (num_selected, 1)
+
                     # Add weighted expert output back to the corresponding tokens
                     output_flat[token_indices] += expert_output * gate_weights
-
+
         return output_flat.view(B, T, D)
 
 
 class Block(nn.Module):
+    @log
     def __init__(self, config, layer_idx):
         super().__init__()
         self.attn = CausalSelfAttention(config, layer_idx)
@@ -283,6 +325,7 @@ class Block(nn.Module):
         else:
             self.mlp = MLP(config)
 
+    @log
     def forward(self, x, cos_sin, kv_cache):
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
@@ -290,6 +333,7 @@
 
 
 class GPT(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -316,15 +360,16 @@ class GPT(nn.Module):
         )  # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
 
+    @log
     def init_weights(self):
         self.apply(self._init_weights)
         # zero out classifier weights
         torch.nn.init.zeros_(self.lm_head.weight)
         # zero out c_proj weights in all blocks
         for block in self.transformer.h:
-            if hasattr(block.mlp, 'c_proj'):  # Regular MLP
+            if hasattr(block.mlp, "c_proj"):  # Regular MLP
                 torch.nn.init.zeros_(block.mlp.c_proj.weight)
-            elif hasattr(block.mlp, 'experts'):  # MoE
+            elif hasattr(block.mlp, "experts"):  # MoE
                 for expert in block.mlp.experts:
                     torch.nn.init.zeros_(expert.c_proj.weight)
             torch.nn.init.zeros_(block.attn.c_proj.weight)
@@ -336,6 +381,7 @@ class GPT(nn.Module):
         if self.transformer.wte.weight.device.type == "cuda":
             self.transformer.wte.to(dtype=torch.bfloat16)
 
+    @log
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             # https://arxiv.org/pdf/2310.17813
@@ -349,6 +395,7 @@ class GPT(nn.Module):
             torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
 
     # TODO: bump base theta more, e.g. 100K is more common more recently
+    @log
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # autodetect the device from model embeddings
         if device is None:
@@ -368,9 +415,11 @@ class GPT(nn.Module):
         )  # add batch and head dims for later broadcasting
         return cos, sin
 
+    @log
     def get_device(self):
         return self.transformer.wte.weight.device
 
+    @log
     def estimate_flops(self):
         """Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311"""
         nparams = sum(p.numel() for p in self.parameters())
@@ -384,6 +433,7 @@ class GPT(nn.Module):
         num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
         return num_flops_per_token
 
+    @log
     def setup_optimizers(
         self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0
     ):
@@ -421,6 +471,7 @@ class GPT(nn.Module):
             group["initial_lr"] = group["lr"]
         return optimizers
 
+    @log
    def forward(self, idx, targets=None, kv_cache=None, loss_reduction="mean"):
         B, T = idx.size()
 
@@ -467,6 +518,7 @@ class GPT(nn.Module):
         logits = softcap * torch.tanh(logits / softcap)  # logits softcap
         return logits
 
+    @log
     @torch.inference_mode()
     def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
         """
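Reviewer note (not part of the patch): the diff relies on a `log` decorator imported via `from logger import log`, but that module is not included in this change. A minimal sketch of what such a decorator might look like, purely as an assumption about its behavior (the logger name, log level, and entry/exit messages below are illustrative, not taken from the repo):

    import functools
    import logging

    _logger = logging.getLogger("nanochat.gpt")

    def log(fn):
        # Assumed behavior: log entry and exit of the wrapped function or method,
        # preserving its signature so it still works when applied to nn.Module
        # methods such as __init__ and forward.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            _logger.debug("enter %s", fn.__qualname__)
            result = fn(*args, **kwargs)
            _logger.debug("exit %s", fn.__qualname__)
            return result
        return wrapper

If the real decorator behaves like this, wrapping hot paths such as CausalSelfAttention.forward and MoE.forward adds a small per-call overhead for every layer on every token, so it may be worth gating the logging behind a log level or an environment flag for training runs.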