diff --git a/nanochat/to_hf.py b/nanochat/to_hf.py index 9ea5e82..c8704f9 100644 --- a/nanochat/to_hf.py +++ b/nanochat/to_hf.py @@ -14,7 +14,6 @@ import argparse import json import os import shutil -import sys from dataclasses import fields as dataclass_fields from pathlib import Path from typing import Optional @@ -31,11 +30,7 @@ except ImportError as exc: from nanochat.common import get_base_dir from nanochat.tokenizer import RustBPETokenizer - -# standard.py expects `manager` to be importable; alias the package module to satisfy that import. -import nanochat.manager as _manager -sys.modules.setdefault("manager", _manager) -from nanochat.standard import GPT, GPTConfig +from nanochat.gpt import GPT, GPTConfig CHAT_TEMPLATE = ( "<|bos|>{% for message in messages %}{% if message['role'] == 'user' %}" @@ -168,7 +163,6 @@ def normalize_config(cfg_kwargs: dict) -> dict: cfg = dict(cfg_kwargs) if "sequence_len" in cfg and "block_size" not in cfg: cfg["block_size"] = cfg.pop("sequence_len") - cfg.pop("n_kv_head", None) return cfg @@ -722,6 +716,33 @@ import torch.nn as nn import torch.nn.functional as F +class MOEManager: + def __init__(self): + self.aux_loss = [] + self.router_z_loss = [] + + def add_aux_loss(self, loss): + self.aux_loss.append(loss) + + def reset_aux_loss(self): + self.aux_loss = [] + + def add_router_z_loss(self, loss): + self.router_z_loss.append(loss) + + def reset_router_z_loss(self): + self.router_z_loss = [] + + def aggregate_aux_loss(self): + return sum(self.aux_loss) + + def aggregate_router_z_loss(self): + return sum(self.router_z_loss) + + +MANAGER = MOEManager() + + class LayerNorm(nn.Module): def __init__(self, ndim, bias): super().__init__() @@ -777,12 +798,14 @@ class Router(nn.Module): super().__init__() self.top_k = config.top_k self.n_exp = config.n_exp - assert 1 <= self.top_k <= config.n_exp + assert self.top_k >= 1 and self.top_k <= config.n_exp self.use_noisy_top_k = config.use_noisy_top_k self.train_capacity = config.train_capacity self.eval_capacity = config.eval_capacity self.min_capacity = config.min_capacity self.router_use_full_prec = config.router_use_full_prec + self.use_aux_loss = config.use_aux_loss + self.use_router_z_loss = config.use_router_z_loss self.w_g = nn.Linear(config.n_embd, config.n_exp, bias=False) self.w_noise = nn.Linear(config.n_embd, config.n_exp, bias=False) if self.use_noisy_top_k else None @@ -797,10 +820,16 @@ class Router(nn.Module): noise = F.softplus(self.w_noise(x)) noise *= torch.randn_like(noise) logits += noise + if self.use_router_z_loss: + z_loss = self.compute_router_z_loss(logits) + MANAGER.add_router_z_loss(z_loss) top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1) router_probs = torch.full_like(logits, float("-inf")) router_probs.scatter_(-1, top_k_indices, top_k_logits) router_probs = F.softmax(router_probs, dim=-1) + if self.use_aux_loss: + aux_loss = self.compute_aux_loss(router_probs, top_k_indices) + MANAGER.add_aux_loss(aux_loss) exp_capacity = self.get_capacity(num_tokens) exp_mask = F.one_hot(top_k_indices, num_classes=self.n_exp) exp_mask = exp_mask.view(num_tokens, self.top_k, self.n_exp) @@ -809,13 +838,26 @@ class Router(nn.Module): exp_rank = torch.cumsum(exp_rank, dim=0) - 1 exp_rank = exp_rank.reshape(self.top_k, num_tokens, self.n_exp) exp_mask *= torch.lt(exp_rank, exp_capacity) + used_capacity = torch.sum(exp_mask, dim=(0, 1)) exp_rank = torch.sum(exp_mask * exp_rank, dim=-1) router_probs = router_probs.view(num_tokens, self.n_exp)[None, :] exp_weights = exp_mask * router_probs exp_rank_sc = F.one_hot(exp_rank, num_classes=exp_capacity) cb_weight = torch.sum(exp_weights.unsqueeze(3) * exp_rank_sc.unsqueeze(2), dim=0) sec_mask = cb_weight.bool() - return cb_weight, sec_mask + return used_capacity, cb_weight, sec_mask + + def compute_aux_loss(self, expert_probs: torch.Tensor, indices: torch.Tensor): + with torch.no_grad(): + one_hot_indices = F.one_hot(indices, num_classes=self.n_exp) + one_hot_indices = torch.sum(one_hot_indices.float(), dim=2) + tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1)) + prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1)) + return self.n_exp * torch.sum(prob_per_expert * tokens_per_expert) + + def compute_router_z_loss(self, logits: torch.Tensor): + z_loss = torch.logsumexp(logits, dim=-1) ** 2.0 + return torch.mean(z_loss) def get_capacity(self, tokens_per_batch): capacity_factor = self.train_capacity if self.training else self.eval_capacity @@ -857,13 +899,13 @@ class MOELayer(nn.Module): def forward(self, x): B, T, n_embd = x.size() num_tokens = B * T - exp_weight, exp_mask = self.router(x) + used_capacity, exp_weight, exp_mask = self.router(x) x = x.view(num_tokens, n_embd) exp_batches = exp_mask.permute(1, 2, 0).type_as(x) @ x exp_out = self.experts(exp_batches) exp_weight = exp_weight.view(num_tokens, -1) exp_out = exp_out.view(-1, n_embd) - output = exp_weight.type_as(exp_out) @ exp_out + output = exp_weight @ exp_out return output.view(B, T, n_embd) @@ -899,13 +941,14 @@ class Block(nn.Module): @dataclass class GPTConfig: + sequence_len: int = 1024 block_size: int = 1024 vocab_size: int = 50304 n_layer: int = 6 n_head: int = 6 + n_kv_head: int = 6 n_embd: int = 384 dropout: float = 0.0 - bias: bool = False n_exp: int = 8 top_k: int = 2 use_aux_loss: bool = True @@ -920,25 +963,29 @@ class GPTConfig: use_switch_tfm_init: bool = True switch_tfm_init_scale: float = 1.0 router_use_full_prec: bool = True + bias: bool = False class GPT(nn.Module): def __init__(self, config): super().__init__() self.config = config - blocks = [] - for i in range(config.n_layer): - use_moe = (config.n_exp > 1) and (i % config.stride == 0) - blocks.append(Block(config, use_moe=use_moe)) + if config.n_exp == 1: + blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]) + else: + blocks = [] + for i in range(config.n_layer): + use_moe = (i % config.stride) == 0 + blocks.append(Block(config, use_moe=use_moe)) + blocks = nn.ModuleList(blocks) self.transformer = nn.ModuleDict(dict( wte=nn.Embedding(config.vocab_size, config.n_embd), wpe=nn.Embedding(config.block_size, config.n_embd), drop=nn.Dropout(config.dropout), - h=nn.ModuleList(blocks), + h=blocks, ln_f=LayerNorm(config.n_embd, bias=config.bias), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.transformer.wte.weight = self.lm_head.weight def forward(self, idx, targets=None): B, T = idx.size() @@ -949,10 +996,18 @@ class GPT(nn.Module): for block in self.transformer.h: x = block(x) x = self.transformer.ln_f(x) - logits = self.lm_head(x) - loss = None if targets is not None: + logits = self.lm_head(x) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + if self.config.n_exp > 1 and self.config.use_aux_loss: + loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss() + MANAGER.reset_aux_loss() + if self.config.n_exp > 1 and self.config.use_router_z_loss: + loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss() + MANAGER.reset_router_z_loss() + else: + logits = self.lm_head(x[:, [-1], :]) + loss = None return logits, loss """ @@ -1024,8 +1079,8 @@ def main(): parser.add_argument( "--tokenizer", choices=["gpt2", "cache"], - default="gpt2", - help="Tokenizer source for export: gpt2 uses tiktoken; cache uses ~/.cache/nanochat/tokenizer", + default="cache", + help="Tokenizer source for export: cache uses ~/.cache/nanochat/tokenizer (default); gpt2 uses tiktoken", ) args = parser.parse_args()