mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Instrument main gpt model with logging
This commit is contained in:
parent
8a40915246
commit
62cfe4d4c3
160
nanochat/gpt.py
160
nanochat/gpt.py
|
|
@ -12,16 +12,17 @@ Notable features:
|
|||
"""
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from logger import log
|
||||
|
||||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.muon import Muon, DistMuon
|
||||
from nanochat.adamw import DistAdamW
|
||||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.muon import DistMuon, Muon
|
||||
from nanochat.vsa import HRROperations
|
||||
|
||||
|
||||
|
|
@ -40,12 +41,13 @@ class GPTConfig:
|
|||
d_vsa: int = 256
|
||||
|
||||
|
||||
|
||||
@log
|
||||
def norm(x):
|
||||
# Purely functional rmsnorm with no learnable params
|
||||
return F.rms_norm(x, (x.size(-1),))
|
||||
|
||||
|
||||
@log
|
||||
def apply_rotary_emb(x, cos, sin):
|
||||
assert x.ndim == 4 # multihead attention
|
||||
d = x.shape[3] // 2
|
||||
|
|
@ -58,6 +60,7 @@ def apply_rotary_emb(x, cos, sin):
|
|||
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
@log
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
|
|
@ -72,6 +75,7 @@ class CausalSelfAttention(nn.Module):
|
|||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
|
||||
@log
|
||||
def forward(self, x, cos_sin, kv_cache):
|
||||
B, T, C = x.size()
|
||||
|
||||
|
|
@ -108,19 +112,15 @@ class CausalSelfAttention(nn.Module):
|
|||
repeat_factor = self.n_head // self.n_kv_head
|
||||
k = k.repeat_interleave(repeat_factor, dim=1)
|
||||
v = v.repeat_interleave(repeat_factor, dim=1)
|
||||
|
||||
|
||||
if kv_cache is None or Tq == Tk:
|
||||
# During training (no KV cache), attend as usual with causal attention
|
||||
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
||||
y = F.scaled_dot_product_attention(
|
||||
q, k, v, is_causal=True
|
||||
)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
elif Tq == 1:
|
||||
# During inference but with a single query in this forward pass:
|
||||
# The query has to attend to all the keys/values in the cache
|
||||
y = F.scaled_dot_product_attention(
|
||||
q, k, v, is_causal=False
|
||||
)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
|
||||
else:
|
||||
# During inference AND we have a chunk of queries in this forward pass:
|
||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||
|
|
@ -134,9 +134,7 @@ class CausalSelfAttention(nn.Module):
|
|||
attn_mask[:, prefix_len:] = torch.tril(
|
||||
torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)
|
||||
)
|
||||
y = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask
|
||||
)
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
|
||||
# Re-assemble the heads side by side and project back to residual stream
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, -1)
|
||||
|
|
@ -145,136 +143,180 @@ class CausalSelfAttention(nn.Module):
|
|||
|
||||
|
||||
class MLP(nn.Module):
|
||||
@log
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||
|
||||
@log
|
||||
def forward(self, x):
|
||||
x = self.c_fc(x)
|
||||
x = F.relu(x).square()
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class Router(nn.Module):
|
||||
@log
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.top_k = config.top_k
|
||||
self.use_vsa = config.use_vsa
|
||||
self.d_vsa = config.d_vsa
|
||||
|
||||
|
||||
# Normal routing gate
|
||||
self.w_gate = nn.Linear(config.n_embd, self.num_experts, bias=False)
|
||||
|
||||
|
||||
# VSA routing components
|
||||
if self.use_vsa:
|
||||
self.vsa_ops = HRROperations()
|
||||
# Project tokens to VSA space
|
||||
self.token_to_vsa = nn.Linear(config.n_embd, self.d_vsa, bias=False)
|
||||
# Fixed expert signatures in VSA space (not learnable)
|
||||
self.register_buffer('expert_signatures', torch.randn(self.num_experts, self.d_vsa))
|
||||
self.register_buffer(
|
||||
"expert_signatures", torch.randn(self.num_experts, self.d_vsa)
|
||||
)
|
||||
# Single VSA superposition cache containing all token-expert bindings
|
||||
self.register_buffer('vsa_cache', torch.zeros(self.d_vsa))
|
||||
self.register_buffer("vsa_cache", torch.zeros(self.d_vsa))
|
||||
# Combine normal and VSA routing
|
||||
self.vsa_weight = nn.Parameter(torch.tensor(0.5)) # Learnable weight for VSA contribution
|
||||
self.vsa_weight = nn.Parameter(
|
||||
torch.tensor(0.5)
|
||||
) # Learnable weight for VSA contribution
|
||||
|
||||
@log
|
||||
def forward(self, x):
|
||||
# x: (B, T, D)
|
||||
B, T, D = x.size()
|
||||
|
||||
|
||||
# Normal routing to determine which experts are activated
|
||||
normal_logits = self.w_gate(x) # (B, T, num_experts)
|
||||
topk_logits, topk_indices = torch.topk(normal_logits, self.top_k, dim=-1) # (B, T, top_k)
|
||||
|
||||
topk_logits, topk_indices = torch.topk(
|
||||
normal_logits, self.top_k, dim=-1
|
||||
) # (B, T, top_k)
|
||||
|
||||
if self.use_vsa:
|
||||
# Project tokens to VSA space
|
||||
token_vsa = self.token_to_vsa(x) # (B, T, d_vsa)
|
||||
|
||||
|
||||
# Update VSA cache: bind activated tokens to their expert signatures and add to single superposition
|
||||
# Only update cache during training mode
|
||||
if self.training:
|
||||
# Vectorized cache update
|
||||
# Flatten indices and tokens for batch processing
|
||||
flat_tokens = token_vsa.unsqueeze(2).expand(-1, -1, self.top_k, -1) # (B, T, top_k, d_vsa)
|
||||
flat_tokens = token_vsa.unsqueeze(2).expand(
|
||||
-1, -1, self.top_k, -1
|
||||
) # (B, T, top_k, d_vsa)
|
||||
flat_tokens = flat_tokens.reshape(-1, self.d_vsa) # (B*T*top_k, d_vsa)
|
||||
flat_indices = topk_indices.reshape(-1) # (B*T*top_k,)
|
||||
|
||||
|
||||
# Get corresponding expert signatures
|
||||
expert_sigs = self.expert_signatures[flat_indices] # (B*T*top_k, d_vsa)
|
||||
|
||||
|
||||
# Vectorized binding using VSA operations
|
||||
bound_tokens = self.vsa_ops.bind(flat_tokens, expert_sigs)
|
||||
|
||||
|
||||
# Sum all bound tokens and add to cache
|
||||
cache_update = torch.sum(bound_tokens, dim=0)
|
||||
self.vsa_cache = self.vsa_cache + cache_update.detach()
|
||||
|
||||
|
||||
# Compute VSA routing scores by querying the single superposition
|
||||
vsa_scores = torch.zeros(B, T, self.num_experts, device=x.device, dtype=x.dtype)
|
||||
vsa_scores = torch.zeros(
|
||||
B, T, self.num_experts, device=x.device, dtype=x.dtype
|
||||
)
|
||||
if torch.norm(self.vsa_cache) > 1e-8: # Only if cache has content
|
||||
# Vectorized similarity computation
|
||||
# Expand tokens to (B, T, num_experts, d_vsa) and expert_signatures to (B, T, num_experts, d_vsa)
|
||||
token_expanded = token_vsa.unsqueeze(2).expand(B, T, self.num_experts, self.d_vsa) # (B, T, num_experts, d_vsa)
|
||||
expert_expanded = self.expert_signatures.unsqueeze(0).unsqueeze(0).expand(B, T, self.num_experts, self.d_vsa) # (B, T, num_experts, d_vsa)
|
||||
|
||||
token_expanded = token_vsa.unsqueeze(2).expand(
|
||||
B, T, self.num_experts, self.d_vsa
|
||||
) # (B, T, num_experts, d_vsa)
|
||||
expert_expanded = (
|
||||
self.expert_signatures.unsqueeze(0)
|
||||
.unsqueeze(0)
|
||||
.expand(B, T, self.num_experts, self.d_vsa)
|
||||
) # (B, T, num_experts, d_vsa)
|
||||
|
||||
# Vectorized binding: bind all tokens with all expert signatures
|
||||
queries = self.vsa_ops.bind(token_expanded.reshape(-1, self.d_vsa), expert_expanded.reshape(-1, self.d_vsa)) # (B*T*num_experts, d_vsa)
|
||||
queries = queries.reshape(B, T, self.num_experts, self.d_vsa) # (B, T, num_experts, d_vsa)
|
||||
|
||||
queries = self.vsa_ops.bind(
|
||||
token_expanded.reshape(-1, self.d_vsa),
|
||||
expert_expanded.reshape(-1, self.d_vsa),
|
||||
) # (B*T*num_experts, d_vsa)
|
||||
queries = queries.reshape(
|
||||
B, T, self.num_experts, self.d_vsa
|
||||
) # (B, T, num_experts, d_vsa)
|
||||
|
||||
# Compute similarities with the cache for all queries at once
|
||||
# Flatten queries and compute all similarities in one batch
|
||||
queries_flat = queries.reshape(-1, self.d_vsa) # (B*T*num_experts, d_vsa)
|
||||
cache_expanded = self.vsa_cache.unsqueeze(0).expand(queries_flat.size(0), -1) # (B*T*num_experts, d_vsa)
|
||||
|
||||
queries_flat = queries.reshape(
|
||||
-1, self.d_vsa
|
||||
) # (B*T*num_experts, d_vsa)
|
||||
cache_expanded = self.vsa_cache.unsqueeze(0).expand(
|
||||
queries_flat.size(0), -1
|
||||
) # (B*T*num_experts, d_vsa)
|
||||
|
||||
# Use VSA similarity operation for all queries
|
||||
similarities_flat = self.vsa_ops.similarity(queries_flat, cache_expanded) # (B*T*num_experts,)
|
||||
similarities_flat = self.vsa_ops.similarity(
|
||||
queries_flat, cache_expanded
|
||||
) # (B*T*num_experts,)
|
||||
vsa_scores = similarities_flat.reshape(B, T, self.num_experts)
|
||||
|
||||
|
||||
# Combine normal and VSA routing with learnable weight
|
||||
combined_logits = normal_logits + torch.sigmoid(self.vsa_weight) * vsa_scores
|
||||
topk_logits, topk_indices = torch.topk(combined_logits, self.top_k, dim=-1) # (B, T, top_k)
|
||||
|
||||
combined_logits = (
|
||||
normal_logits + torch.sigmoid(self.vsa_weight) * vsa_scores
|
||||
)
|
||||
topk_logits, topk_indices = torch.topk(
|
||||
combined_logits, self.top_k, dim=-1
|
||||
) # (B, T, top_k)
|
||||
|
||||
topk_gates = F.softmax(topk_logits, dim=-1) # (B, T, top_k)
|
||||
return topk_indices, topk_gates # both are (B, T, top_k)
|
||||
|
||||
|
||||
|
||||
class MoE(nn.Module):
|
||||
@log
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
|
||||
self.router = Router(config)
|
||||
|
||||
@log
|
||||
def forward(self, x):
|
||||
# x: (B, T, D)
|
||||
B, T, D = x.size()
|
||||
topk_indices, topk_gates = self.router(x) # both are (B, T, top_k)
|
||||
output = torch.zeros_like(x)
|
||||
|
||||
|
||||
# Flatten for easier token-level processing
|
||||
x_flat = x.view(-1, D) # (B*T, D)
|
||||
output_flat = output.view(-1, D) # (B*T, D)
|
||||
topk_indices_flat = topk_indices.view(-1, topk_indices.size(-1)) # (B*T, top_k)
|
||||
topk_gates_flat = topk_gates.view(-1, topk_gates.size(-1)) # (B*T, top_k)
|
||||
|
||||
|
||||
for i in range(self.num_experts):
|
||||
# Find which tokens use expert i and their corresponding gate values
|
||||
expert_mask = (topk_indices_flat == i) # (B*T, top_k)
|
||||
expert_mask = topk_indices_flat == i # (B*T, top_k)
|
||||
if expert_mask.any():
|
||||
# Get tokens that route to this expert
|
||||
token_indices, k_indices = torch.where(expert_mask)
|
||||
if len(token_indices) > 0:
|
||||
selected_tokens = x_flat[token_indices] # (num_selected, D)
|
||||
expert_output = self.experts[i](selected_tokens) # (num_selected, D)
|
||||
gate_weights = topk_gates_flat[token_indices, k_indices].unsqueeze(-1) # (num_selected, 1)
|
||||
|
||||
expert_output = self.experts[i](
|
||||
selected_tokens
|
||||
) # (num_selected, D)
|
||||
gate_weights = topk_gates_flat[token_indices, k_indices].unsqueeze(
|
||||
-1
|
||||
) # (num_selected, 1)
|
||||
|
||||
# Add weighted expert output back to the corresponding tokens
|
||||
output_flat[token_indices] += expert_output * gate_weights
|
||||
|
||||
|
||||
return output_flat.view(B, T, D)
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
@log
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
|
|
@ -283,6 +325,7 @@ class Block(nn.Module):
|
|||
else:
|
||||
self.mlp = MLP(config)
|
||||
|
||||
@log
|
||||
def forward(self, x, cos_sin, kv_cache):
|
||||
x = x + self.attn(norm(x), cos_sin, kv_cache)
|
||||
x = x + self.mlp(norm(x))
|
||||
|
|
@ -290,6 +333,7 @@ class Block(nn.Module):
|
|||
|
||||
|
||||
class GPT(nn.Module):
|
||||
@log
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
|
@ -316,15 +360,16 @@ class GPT(nn.Module):
|
|||
) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
|
||||
@log
|
||||
def init_weights(self):
|
||||
self.apply(self._init_weights)
|
||||
# zero out classifier weights
|
||||
torch.nn.init.zeros_(self.lm_head.weight)
|
||||
# zero out c_proj weights in all blocks
|
||||
for block in self.transformer.h:
|
||||
if hasattr(block.mlp, 'c_proj'): # Regular MLP
|
||||
if hasattr(block.mlp, "c_proj"): # Regular MLP
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
elif hasattr(block.mlp, 'experts'): # MoE
|
||||
elif hasattr(block.mlp, "experts"): # MoE
|
||||
for expert in block.mlp.experts:
|
||||
torch.nn.init.zeros_(expert.c_proj.weight)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight)
|
||||
|
|
@ -336,6 +381,7 @@ class GPT(nn.Module):
|
|||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
|
||||
@log
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
# https://arxiv.org/pdf/2310.17813
|
||||
|
|
@ -349,6 +395,7 @@ class GPT(nn.Module):
|
|||
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
|
||||
|
||||
# TODO: bump base theta more, e.g. 100K is more common more recently
|
||||
@log
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# autodetect the device from model embeddings
|
||||
if device is None:
|
||||
|
|
@ -368,9 +415,11 @@ class GPT(nn.Module):
|
|||
) # add batch and head dims for later broadcasting
|
||||
return cos, sin
|
||||
|
||||
@log
|
||||
def get_device(self):
|
||||
return self.transformer.wte.weight.device
|
||||
|
||||
@log
|
||||
def estimate_flops(self):
|
||||
"""Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
|
|
@ -384,6 +433,7 @@ class GPT(nn.Module):
|
|||
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
||||
return num_flops_per_token
|
||||
|
||||
@log
|
||||
def setup_optimizers(
|
||||
self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0
|
||||
):
|
||||
|
|
@ -421,6 +471,7 @@ class GPT(nn.Module):
|
|||
group["initial_lr"] = group["lr"]
|
||||
return optimizers
|
||||
|
||||
@log
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction="mean"):
|
||||
B, T = idx.size()
|
||||
|
||||
|
|
@ -467,6 +518,7 @@ class GPT(nn.Module):
|
|||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
return logits
|
||||
|
||||
@log
|
||||
@torch.inference_mode()
|
||||
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user