mirror of https://github.com/karpathy/nanochat.git
synced 2026-03-20 11:53:13 +00:00

to_hf adjusted to current implementation

This commit is contained in:
parent 952ea5137a
commit 8f1378235e
@@ -14,7 +14,6 @@ import argparse
 import json
 import os
 import shutil
-import sys
 from dataclasses import fields as dataclass_fields
 from pathlib import Path
 from typing import Optional
@@ -31,11 +30,7 @@ except ImportError as exc:
 
 from nanochat.common import get_base_dir
 from nanochat.tokenizer import RustBPETokenizer
-
-# standard.py expects `manager` to be importable; alias the package module to satisfy that import.
-import nanochat.manager as _manager
-sys.modules.setdefault("manager", _manager)
-from nanochat.standard import GPT, GPTConfig
+from nanochat.gpt import GPT, GPTConfig
 
 CHAT_TEMPLATE = (
     "<|bos|>{% for message in messages %}{% if message['role'] == 'user' %}"
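For context on the deleted alias: the old standard.py expected a bare `import manager` to work, so the exporter pre-registered `nanochat.manager` under that top-level name. A minimal sketch of the trick, using a synthetic module rather than the real nanochat package:

import sys
import types

# Registering an existing module object under a second top-level name makes a
# bare `import manager` succeed without a manager.py anywhere on sys.path.
_manager = types.ModuleType("nanochat.manager")
sys.modules.setdefault("manager", _manager)

import manager  # resolved from sys.modules; no file lookup happens
assert manager is _manager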
@@ -168,7 +163,6 @@ def normalize_config(cfg_kwargs: dict) -> dict:
     cfg = dict(cfg_kwargs)
     if "sequence_len" in cfg and "block_size" not in cfg:
         cfg["block_size"] = cfg.pop("sequence_len")
-    cfg.pop("n_kv_head", None)
     return cfg
 
 
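With n_kv_head now a real field on the exported GPTConfig (see the config hunk below), normalize_config only has to rename the legacy sequence_len key. A quick check of the surviving behavior:

def normalize_config(cfg_kwargs: dict) -> dict:
    # Post-change behavior: rename sequence_len -> block_size, pass the rest through.
    cfg = dict(cfg_kwargs)
    if "sequence_len" in cfg and "block_size" not in cfg:
        cfg["block_size"] = cfg.pop("sequence_len")
    return cfg

print(normalize_config({"sequence_len": 2048, "n_kv_head": 6}))
# {'block_size': 2048, 'n_kv_head': 6}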
@@ -722,6 +716,33 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
+class MOEManager:
+    def __init__(self):
+        self.aux_loss = []
+        self.router_z_loss = []
+
+    def add_aux_loss(self, loss):
+        self.aux_loss.append(loss)
+
+    def reset_aux_loss(self):
+        self.aux_loss = []
+
+    def add_router_z_loss(self, loss):
+        self.router_z_loss.append(loss)
+
+    def reset_router_z_loss(self):
+        self.router_z_loss = []
+
+    def aggregate_aux_loss(self):
+        return sum(self.aux_loss)
+
+    def aggregate_router_z_loss(self):
+        return sum(self.router_z_loss)
+
+
+MANAGER = MOEManager()
+
+
 class LayerNorm(nn.Module):
     def __init__(self, ndim, bias):
         super().__init__()
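MANAGER is a process-global accumulator: every MoE router appends its layer's losses during the forward pass, and the model drains them once per step. A standalone sketch of that lifecycle (aux-loss half only, trimmed from the class above):

class MOEManager:  # trimmed to the aux-loss half of the class added above
    def __init__(self):
        self.aux_loss = []
    def add_aux_loss(self, loss):
        self.aux_loss.append(loss)
    def aggregate_aux_loss(self):
        return sum(self.aux_loss)
    def reset_aux_loss(self):
        self.aux_loss = []

manager = MOEManager()
for layer_aux in (0.11, 0.09, 0.14):   # one entry per MoE layer in the forward pass
    manager.add_aux_loss(layer_aux)
assert abs(manager.aggregate_aux_loss() - 0.34) < 1e-9
manager.reset_aux_loss()               # reset once per step so losses don't leak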
@@ -777,12 +798,14 @@ class Router(nn.Module):
         super().__init__()
         self.top_k = config.top_k
         self.n_exp = config.n_exp
-        assert 1 <= self.top_k <= config.n_exp
+        assert self.top_k >= 1 and self.top_k <= config.n_exp
         self.use_noisy_top_k = config.use_noisy_top_k
         self.train_capacity = config.train_capacity
         self.eval_capacity = config.eval_capacity
         self.min_capacity = config.min_capacity
         self.router_use_full_prec = config.router_use_full_prec
+        self.use_aux_loss = config.use_aux_loss
+        self.use_router_z_loss = config.use_router_z_loss
         self.w_g = nn.Linear(config.n_embd, config.n_exp, bias=False)
         self.w_noise = nn.Linear(config.n_embd, config.n_exp, bias=False) if self.use_noisy_top_k else None
 
@@ -797,10 +820,16 @@ class Router(nn.Module):
             noise = F.softplus(self.w_noise(x))
             noise *= torch.randn_like(noise)
             logits += noise
+        if self.use_router_z_loss:
+            z_loss = self.compute_router_z_loss(logits)
+            MANAGER.add_router_z_loss(z_loss)
         top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)
         router_probs = torch.full_like(logits, float("-inf"))
         router_probs.scatter_(-1, top_k_indices, top_k_logits)
         router_probs = F.softmax(router_probs, dim=-1)
+        if self.use_aux_loss:
+            aux_loss = self.compute_aux_loss(router_probs, top_k_indices)
+            MANAGER.add_aux_loss(aux_loss)
         exp_capacity = self.get_capacity(num_tokens)
         exp_mask = F.one_hot(top_k_indices, num_classes=self.n_exp)
         exp_mask = exp_mask.view(num_tokens, self.top_k, self.n_exp)
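The scatter-into--inf pattern above is a masked softmax: probability mass is renormalized over only the selected experts. A small runnable check:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])        # 1 token, 4 experts
top_k_logits, top_k_indices = logits.topk(2, dim=-1)  # picks experts 0 and 2
router_probs = torch.full_like(logits, float("-inf"))
router_probs.scatter_(-1, top_k_indices, top_k_logits)
router_probs = F.softmax(router_probs, dim=-1)
print(router_probs)  # nonzero only at indices 0 and 2, and those entries sum to 1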
@@ -809,13 +838,26 @@ class Router(nn.Module):
         exp_rank = torch.cumsum(exp_rank, dim=0) - 1
         exp_rank = exp_rank.reshape(self.top_k, num_tokens, self.n_exp)
         exp_mask *= torch.lt(exp_rank, exp_capacity)
+        used_capacity = torch.sum(exp_mask, dim=(0, 1))
         exp_rank = torch.sum(exp_mask * exp_rank, dim=-1)
         router_probs = router_probs.view(num_tokens, self.n_exp)[None, :]
         exp_weights = exp_mask * router_probs
         exp_rank_sc = F.one_hot(exp_rank, num_classes=exp_capacity)
         cb_weight = torch.sum(exp_weights.unsqueeze(3) * exp_rank_sc.unsqueeze(2), dim=0)
         sec_mask = cb_weight.bool()
-        return cb_weight, sec_mask
+        return used_capacity, cb_weight, sec_mask
+
+    def compute_aux_loss(self, expert_probs: torch.Tensor, indices: torch.Tensor):
+        with torch.no_grad():
+            one_hot_indices = F.one_hot(indices, num_classes=self.n_exp)
+            one_hot_indices = torch.sum(one_hot_indices.float(), dim=2)
+            tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1))
+        prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1))
+        return self.n_exp * torch.sum(prob_per_expert * tokens_per_expert)
+
+    def compute_router_z_loss(self, logits: torch.Tensor):
+        z_loss = torch.logsumexp(logits, dim=-1) ** 2.0
+        return torch.mean(z_loss)
 
     def get_capacity(self, tokens_per_batch):
         capacity_factor = self.train_capacity if self.training else self.eval_capacity
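get_capacity is cut off by this hunk. In the Switch-Transformer scheme these fields imply, each expert gets roughly tokens * capacity_factor / n_exp slots, floored at min_capacity; the exact rounding below is an assumption, not the file's code:

import math

# Sketch of the capacity rule implied by train/eval capacity and min_capacity;
# the real body of get_capacity lies outside this hunk, so rounding is assumed.
def get_capacity(tokens_per_batch, n_exp, capacity_factor, min_capacity):
    capacity = math.floor(tokens_per_batch * capacity_factor / n_exp)
    return max(capacity, min_capacity)

# 1024 tokens over 8 experts at train capacity 1.25 -> 160 slots per expert
print(get_capacity(1024, n_exp=8, capacity_factor=1.25, min_capacity=4))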
@@ -857,13 +899,13 @@ class MOELayer(nn.Module):
     def forward(self, x):
         B, T, n_embd = x.size()
         num_tokens = B * T
-        exp_weight, exp_mask = self.router(x)
+        used_capacity, exp_weight, exp_mask = self.router(x)
         x = x.view(num_tokens, n_embd)
         exp_batches = exp_mask.permute(1, 2, 0).type_as(x) @ x
         exp_out = self.experts(exp_batches)
         exp_weight = exp_weight.view(num_tokens, -1)
         exp_out = exp_out.view(-1, n_embd)
-        output = exp_weight.type_as(exp_out) @ exp_out
+        output = exp_weight @ exp_out
         return output.view(B, T, n_embd)
 
 
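The dispatch and combine in forward are two plain matmuls. A toy-shape walkthrough (all sizes hypothetical) confirms the algebra: sec_mask permuted to (n_exp, capacity, num_tokens) gathers tokens into per-expert batches, and the flattened combine weights scatter expert outputs back per token:

import torch

num_tokens, n_embd, n_exp, cap = 6, 4, 3, 2        # toy sizes
x = torch.randn(num_tokens, n_embd)

sec_mask = torch.zeros(n_exp, cap, num_tokens)     # exp_mask.permute(1, 2, 0)
sec_mask[0, 0, 1] = 1.0                            # expert 0, slot 0 takes token 1

exp_batches = sec_mask @ x                         # (n_exp, cap, n_embd): gathered inputs
exp_out = exp_batches                              # stand-in for self.experts(...)

exp_weight = torch.zeros(num_tokens, n_exp * cap)  # cb_weight flattened per token
exp_weight[1, 0] = 1.0                             # token 1 reads expert 0, slot 0
output = exp_weight @ exp_out.reshape(-1, n_embd)  # (num_tokens, n_embd)
assert output.shape == (num_tokens, n_embd)
assert torch.allclose(output[1], x[1])             # token round-trips through its slot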
@@ -899,13 +941,14 @@ class Block(nn.Module):
 
 @dataclass
 class GPTConfig:
-    sequence_len: int = 1024
+    block_size: int = 1024
     vocab_size: int = 50304
     n_layer: int = 6
     n_head: int = 6
+    n_kv_head: int = 6
     n_embd: int = 384
     dropout: float = 0.0
     bias: bool = False
     n_exp: int = 8
     top_k: int = 2
     use_aux_loss: bool = True
@@ -920,25 +963,29 @@ class GPTConfig:
     use_switch_tfm_init: bool = True
     switch_tfm_init_scale: float = 1.0
     router_use_full_prec: bool = True
+    bias: bool = False
 
 
 class GPT(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        blocks = []
-        for i in range(config.n_layer):
-            use_moe = (config.n_exp > 1) and (i % config.stride == 0)
-            blocks.append(Block(config, use_moe=use_moe))
+        if config.n_exp == 1:
+            blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
+        else:
+            blocks = []
+            for i in range(config.n_layer):
+                use_moe = (i % config.stride) == 0
+                blocks.append(Block(config, use_moe=use_moe))
+            blocks = nn.ModuleList(blocks)
         self.transformer = nn.ModuleDict(dict(
             wte=nn.Embedding(config.vocab_size, config.n_embd),
             wpe=nn.Embedding(config.block_size, config.n_embd),
             drop=nn.Dropout(config.dropout),
-            h=nn.ModuleList(blocks),
+            h=blocks,
             ln_f=LayerNorm(config.n_embd, bias=config.bias),
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         self.transformer.wte.weight = self.lm_head.weight
 
     def forward(self, idx, targets=None):
         B, T = idx.size()
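The rewritten constructor keeps dense checkpoints untouched (n_exp == 1) and otherwise makes every stride-th block an MoE block, counting from layer 0:

n_layer, stride = 6, 2                 # example values; stride comes from GPTConfig
print([(i % stride) == 0 for i in range(n_layer)])
# [True, False, True, False, True, False] -> MoE blocks at layers 0, 2, 4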
@@ -949,10 +996,18 @@ class GPT(nn.Module):
         for block in self.transformer.h:
             x = block(x)
         x = self.transformer.ln_f(x)
-        logits = self.lm_head(x)
-        loss = None
         if targets is not None:
+            logits = self.lm_head(x)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+            if self.config.n_exp > 1 and self.config.use_aux_loss:
+                loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
+                MANAGER.reset_aux_loss()
+            if self.config.n_exp > 1 and self.config.use_router_z_loss:
+                loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
+                MANAGER.reset_router_z_loss()
+        else:
+            logits = self.lm_head(x[:, [-1], :])
+            loss = None
         return logits, loss
 """
 
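One consequence of the new inference branch: with targets=None the head runs on the last position only, so logits shrink from (B, T, vocab) to (B, 1, vocab); the list index [-1] keeps the time dimension. A shape check:

import torch

B, T, n_embd, vocab = 2, 8, 16, 50      # toy sizes
x = torch.randn(B, T, n_embd)
lm_head = torch.nn.Linear(n_embd, vocab, bias=False)

logits = lm_head(x[:, [-1], :])         # inference path: last position only
assert logits.shape == (B, 1, vocab)    # vs (B, T, vocab) on the training path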
@@ -1024,8 +1079,8 @@ def main():
     parser.add_argument(
         "--tokenizer",
         choices=["gpt2", "cache"],
-        default="gpt2",
-        help="Tokenizer source for export: gpt2 uses tiktoken; cache uses ~/.cache/nanochat/tokenizer",
+        default="cache",
+        help="Tokenizer source for export: cache uses ~/.cache/nanochat/tokenizer (default); gpt2 uses tiktoken",
     )
     args = parser.parse_args()
 
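The flipped default means a bare run of the exporter now picks up the trained tokenizer from ~/.cache/nanochat/tokenizer, while --tokenizer gpt2 restores the tiktoken path. The argparse behavior in isolation:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer", choices=["gpt2", "cache"], default="cache")

print(parser.parse_args([]).tokenizer)                       # 'cache' (new default)
print(parser.parse_args(["--tokenizer", "gpt2"]).tokenizer)  # 'gpt2' (old behavior)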