to_hf adjusted to current implementation

Muheng 2026-01-06 06:34:46 +00:00
parent 952ea5137a
commit 8f1378235e


@@ -14,7 +14,6 @@ import argparse
import json
import os
import shutil
-import sys
from dataclasses import fields as dataclass_fields
from pathlib import Path
from typing import Optional
@@ -31,11 +30,7 @@ except ImportError as exc:
from nanochat.common import get_base_dir
from nanochat.tokenizer import RustBPETokenizer
-# standard.py expects `manager` to be importable; alias the package module to satisfy that import.
-import nanochat.manager as _manager
-sys.modules.setdefault("manager", _manager)
-from nanochat.standard import GPT, GPTConfig
+from nanochat.gpt import GPT, GPTConfig

CHAT_TEMPLATE = (
    "<|bos|>{% for message in messages %}{% if message['role'] == 'user' %}"
@@ -168,7 +163,6 @@ def normalize_config(cfg_kwargs: dict) -> dict:
    cfg = dict(cfg_kwargs)
    if "sequence_len" in cfg and "block_size" not in cfg:
        cfg["block_size"] = cfg.pop("sequence_len")
-    cfg.pop("n_kv_head", None)
    return cfg
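For reference, the normalization that remains maps nanochat-style checkpoint keys onto the embedded config; after this change n_kv_head passes through instead of being dropped, consistent with it becoming a GPTConfig field below. A minimal sketch of the behavior (input dict illustrative):

```python
# Illustrative checkpoint config; key names are the ones handled above.
cfg = {"sequence_len": 2048, "vocab_size": 50304, "n_kv_head": 6}
if "sequence_len" in cfg and "block_size" not in cfg:
    cfg["block_size"] = cfg.pop("sequence_len")
# n_kv_head is no longer popped, so it survives into GPTConfig(**cfg).
print(cfg)  # {'vocab_size': 50304, 'n_kv_head': 6, 'block_size': 2048}
```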
@@ -722,6 +716,33 @@ import torch.nn as nn
import torch.nn.functional as F

+class MOEManager:
+    def __init__(self):
+        self.aux_loss = []
+        self.router_z_loss = []
+
+    def add_aux_loss(self, loss):
+        self.aux_loss.append(loss)
+
+    def reset_aux_loss(self):
+        self.aux_loss = []
+
+    def add_router_z_loss(self, loss):
+        self.router_z_loss.append(loss)
+
+    def reset_router_z_loss(self):
+        self.router_z_loss = []
+
+    def aggregate_aux_loss(self):
+        return sum(self.aux_loss)
+
+    def aggregate_router_z_loss(self):
+        return sum(self.router_z_loss)
+
+MANAGER = MOEManager()
+
class LayerNorm(nn.Module):
    def __init__(self, ndim, bias):
        super().__init__()
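A note on the MANAGER pattern introduced above: it is a module-level accumulator, so routers buried in the block stack can record per-layer losses without threading them through return values, and GPT.forward drains and resets it once per step. A minimal round trip, reusing the MOEManager class from the diff (values illustrative):

```python
manager = MOEManager()
manager.add_aux_loss(0.25)   # e.g. pushed by the layer-0 router
manager.add_aux_loss(0.50)   # e.g. pushed by the layer-2 router
assert manager.aggregate_aux_loss() == 0.75
manager.reset_aux_loss()     # must be reset each step so losses don't leak
assert manager.aggregate_aux_loss() == 0   # sum([]) == 0
```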
@@ -777,12 +798,14 @@ class Router(nn.Module):
        super().__init__()
        self.top_k = config.top_k
        self.n_exp = config.n_exp
-        assert 1 <= self.top_k <= config.n_exp
+        assert self.top_k >= 1 and self.top_k <= config.n_exp
        self.use_noisy_top_k = config.use_noisy_top_k
        self.train_capacity = config.train_capacity
        self.eval_capacity = config.eval_capacity
        self.min_capacity = config.min_capacity
        self.router_use_full_prec = config.router_use_full_prec
+        self.use_aux_loss = config.use_aux_loss
+        self.use_router_z_loss = config.use_router_z_loss
        self.w_g = nn.Linear(config.n_embd, config.n_exp, bias=False)
        self.w_noise = nn.Linear(config.n_embd, config.n_exp, bias=False) if self.use_noisy_top_k else None
@@ -797,10 +820,16 @@
            noise = F.softplus(self.w_noise(x))
            noise *= torch.randn_like(noise)
            logits += noise
+        if self.use_router_z_loss:
+            z_loss = self.compute_router_z_loss(logits)
+            MANAGER.add_router_z_loss(z_loss)
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)
        router_probs = torch.full_like(logits, float("-inf"))
        router_probs.scatter_(-1, top_k_indices, top_k_logits)
        router_probs = F.softmax(router_probs, dim=-1)
+        if self.use_aux_loss:
+            aux_loss = self.compute_aux_loss(router_probs, top_k_indices)
+            MANAGER.add_aux_loss(aux_loss)
        exp_capacity = self.get_capacity(num_tokens)
        exp_mask = F.one_hot(top_k_indices, num_classes=self.n_exp)
        exp_mask = exp_mask.view(num_tokens, self.top_k, self.n_exp)
@@ -809,13 +838,26 @@
        exp_rank = torch.cumsum(exp_rank, dim=0) - 1
        exp_rank = exp_rank.reshape(self.top_k, num_tokens, self.n_exp)
        exp_mask *= torch.lt(exp_rank, exp_capacity)
+        used_capacity = torch.sum(exp_mask, dim=(0, 1))
        exp_rank = torch.sum(exp_mask * exp_rank, dim=-1)
        router_probs = router_probs.view(num_tokens, self.n_exp)[None, :]
        exp_weights = exp_mask * router_probs
        exp_rank_sc = F.one_hot(exp_rank, num_classes=exp_capacity)
        cb_weight = torch.sum(exp_weights.unsqueeze(3) * exp_rank_sc.unsqueeze(2), dim=0)
        sec_mask = cb_weight.bool()
-        return cb_weight, sec_mask
+        return used_capacity, cb_weight, sec_mask
+
+    def compute_aux_loss(self, expert_probs: torch.Tensor, indices: torch.Tensor):
+        with torch.no_grad():
+            one_hot_indices = F.one_hot(indices, num_classes=self.n_exp)
+            one_hot_indices = torch.sum(one_hot_indices.float(), dim=2)
+            tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1))
+        prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1))
+        return self.n_exp * torch.sum(prob_per_expert * tokens_per_expert)
+
+    def compute_router_z_loss(self, logits: torch.Tensor):
+        z_loss = torch.logsumexp(logits, dim=-1) ** 2.0
+        return torch.mean(z_loss)

    def get_capacity(self, tokens_per_batch):
        capacity_factor = self.train_capacity if self.training else self.eval_capacity
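The two methods added here are the standard MoE regularizers: the Switch Transformer load-balancing loss (n_exp times the dot product between the average per-expert assignment count and the mean router probability per expert) and the ST-MoE router z-loss (mean squared logsumexp of the router logits, which penalizes large logit magnitudes). A self-contained sketch on dummy tensors, shapes as in the diff (logits [B, T, n_exp], indices [B, T, top_k]):

```python
import torch
import torch.nn.functional as F

B, T, top_k, n_exp = 2, 4, 2, 8
logits = torch.randn(B, T, n_exp)

# Router z-loss: mean of logsumexp(logits)^2 over all tokens.
z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2.0)

# Load-balancing aux loss on the same logits.
probs = F.softmax(logits, dim=-1)
indices = probs.topk(top_k, dim=-1).indices                # [B, T, top_k]
one_hot = F.one_hot(indices, num_classes=n_exp).float()    # [B, T, top_k, n_exp]
tokens_per_expert = one_hot.sum(dim=2).mean(dim=(0, 1))    # avg assignments per expert
prob_per_expert = probs.mean(dim=(0, 1))                   # avg router prob per expert
aux_loss = n_exp * torch.sum(prob_per_expert * tokens_per_expert)
```

For perfectly balanced routing the aux term comes out near top_k, which is why it only enters the total loss scaled by aux_loss_weight.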
@@ -857,13 +899,13 @@ class MOELayer(nn.Module):
    def forward(self, x):
        B, T, n_embd = x.size()
        num_tokens = B * T
-        exp_weight, exp_mask = self.router(x)
+        used_capacity, exp_weight, exp_mask = self.router(x)
        x = x.view(num_tokens, n_embd)
        exp_batches = exp_mask.permute(1, 2, 0).type_as(x) @ x
        exp_out = self.experts(exp_batches)
        exp_weight = exp_weight.view(num_tokens, -1)
        exp_out = exp_out.view(-1, n_embd)
-        output = exp_weight.type_as(exp_out) @ exp_out
+        output = exp_weight @ exp_out
        return output.view(B, T, n_embd)
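For orientation, the dispatch/combine above is two matmuls against one-hot-derived masks: exp_mask gathers each routed token into its (expert, capacity-slot) batch, and exp_weight scatters the expert outputs back, weighted by router probability. A toy walk-through with the same tensor layouts (sizes illustrative; the expert MLPs are replaced by an identity stand-in):

```python
import torch

num_tokens, n_exp, capacity, n_embd = 8, 4, 2, 16
x = torch.randn(num_tokens, n_embd)
exp_mask = torch.zeros(num_tokens, n_exp, capacity)    # token -> (expert, slot)
exp_weight = torch.zeros(num_tokens, n_exp, capacity)  # router prob at that slot
exp_mask[0, 1, 0] = 1.0
exp_weight[0, 1, 0] = 0.9

exp_batches = exp_mask.permute(1, 2, 0) @ x            # gather: [n_exp, capacity, n_embd]
exp_out = exp_batches                                  # identity stand-in for self.experts
output = exp_weight.reshape(num_tokens, -1) @ exp_out.reshape(-1, n_embd)
# output[0] == 0.9 * x[0]; tokens dropped by the capacity limit come back as zeros.
```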
@@ -899,13 +941,14 @@ class Block(nn.Module):

@dataclass
class GPTConfig:
+    sequence_len: int = 1024
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 6
    n_head: int = 6
+    n_kv_head: int = 6
    n_embd: int = 384
    dropout: float = 0.0
-    bias: bool = False
    n_exp: int = 8
    top_k: int = 2
    use_aux_loss: bool = True
@@ -920,25 +963,29 @@ class GPTConfig:
    use_switch_tfm_init: bool = True
    switch_tfm_init_scale: float = 1.0
    router_use_full_prec: bool = True
+    bias: bool = False

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
-        blocks = []
-        for i in range(config.n_layer):
-            use_moe = (config.n_exp > 1) and (i % config.stride == 0)
-            blocks.append(Block(config, use_moe=use_moe))
+        if config.n_exp == 1:
+            blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
+        else:
+            blocks = []
+            for i in range(config.n_layer):
+                use_moe = (i % config.stride) == 0
+                blocks.append(Block(config, use_moe=use_moe))
+            blocks = nn.ModuleList(blocks)
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
-            h=nn.ModuleList(blocks),
+            h=blocks,
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, idx, targets=None):
        B, T = idx.size()
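The rewritten constructor keeps a dense-only path for n_exp == 1 and otherwise turns every stride-th block into an MoE block. For illustration, the layout the loop produces:

```python
n_layer, stride = 6, 2
pattern = ["moe" if (i % stride) == 0 else "dense" for i in range(n_layer)]
print(pattern)  # ['moe', 'dense', 'moe', 'dense', 'moe', 'dense']
```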
@@ -949,10 +996,18 @@ class GPT(nn.Module):
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
-        logits = self.lm_head(x)
-        loss = None
        if targets is not None:
+            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+            if self.config.n_exp > 1 and self.config.use_aux_loss:
+                loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
+                MANAGER.reset_aux_loss()
+            if self.config.n_exp > 1 and self.config.use_router_z_loss:
+                loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
+                MANAGER.reset_router_z_loss()
+        else:
+            logits = self.lm_head(x[:, [-1], :])
+            loss = None
        return logits, loss
"""
@@ -1024,8 +1079,8 @@ def main():
    parser.add_argument(
        "--tokenizer",
        choices=["gpt2", "cache"],
-        default="gpt2",
-        help="Tokenizer source for export: gpt2 uses tiktoken; cache uses ~/.cache/nanochat/tokenizer",
+        default="cache",
+        help="Tokenizer source for export: cache uses ~/.cache/nanochat/tokenizer (default); gpt2 uses tiktoken",
    )
    args = parser.parse_args()