nanochat/nanochat/gpt.py
"""
GPT model (rewrite, a lot simpler)
Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- tied weights for token embedding and lm_head
- relu^2 activation in MLP
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
"""
import math
import inspect
from functools import partial
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
# ========== MOE ADDITION START ==========
from nanochat.manager import MANAGER
from contextlib import nullcontext
# ========== MOE ADDITION END ==========
@dataclass
class GPTConfig:
sequence_len: int = 1024
    block_size: int = 1024 # max context length (same role as sequence_len); used by the attention mask and position embeddings
vocab_size: int = 50304
n_layer: int = 6
n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (GQA); declared here but not used by CausalSelfAttention below
n_embd: int = 384
dropout: float = 0.0 # dropout rate
# ========== MOE ADDITION START ==========
# MoE-related configs (added for MoE support)
n_exp: int = 8 # if n_exp = 1 we just use regular MLP layers
top_k: int = 2 # number of active experts
use_aux_loss: bool = True # apply auxiliary loss (from Switch Transformer)
use_router_z_loss: bool = True # apply router z loss (from ST-MoE)
use_noisy_top_k: bool = False
aux_loss_weight: float = 0.01 # default from Switch Transformer
router_z_loss_weight: float = 0.001 # default from ST-MoE
train_capacity: float = 1.25 # default from ST-MoE
eval_capacity: float = 2.0
min_capacity: int = 4 # minimum batch size per expert
stride: int = 2 # one in every stride layers uses MoE
use_switch_tfm_init: bool = True # use weight init scheme from Switch Transformer
switch_tfm_init_scale: float = 1.0
router_use_full_prec: bool = True # use float32 in router
bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
# ========== MOE ADDITION END ==========
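# Note on MoE placement (see GPT.__init__ below): with the defaults above (n_layer = 6,
# stride = 2, n_exp = 8), layers 0, 2 and 4 use an MOELayer and layers 1, 3 and 5 use a
# dense MLP, since layer i is MoE whenever i % stride == 0.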
class LayerNorm(nn.Module):
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
def __init__(self, ndim, bias):
super().__init__()
self.weight = nn.Parameter(torch.ones(ndim))
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
# regularization
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.dropout = config.dropout
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
else:
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
self.gelu = nn.GELU()
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
x = self.c_fc(x)
x = self.gelu(x)
x = self.c_proj(x)
x = self.dropout(x)
return x
class Router(nn.Module):
def __init__(self, config):
super().__init__()
# router settings
self.top_k = config.top_k
self.n_exp = config.n_exp
assert self.top_k >= 1 and self.top_k <= config.n_exp
self.use_noisy_top_k = config.use_noisy_top_k
self.train_capacity = config.train_capacity
self.eval_capacity = config.eval_capacity
self.min_capacity = config.min_capacity
self.router_use_full_prec = config.router_use_full_prec
# auxiliary / load balancing loss settings
self.use_aux_loss = config.use_aux_loss
self.use_router_z_loss = config.use_router_z_loss
# linear projection for (noisy) softmax gating
# no bias is used, see page 4 eq (4) in (https://arxiv.org/abs/1701.06538)
self.w_g = nn.Linear(config.n_embd, config.n_exp, bias=False)
self.w_noise = nn.Linear(config.n_embd, config.n_exp, bias=False) if self.use_noisy_top_k else None
def forward(self, x):
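        """
        Route each token to its top-k experts under a fixed per-expert capacity.
        Returns:
            used_capacity: [n_exp], number of capacity slots filled for each expert
            cb_weight: [B * T, n_exp, exp_capacity], router weight of each token at its
                assigned slot (zero everywhere else)
            sec_mask: [B * T, n_exp, exp_capacity], boolean mask of the assigned slots
        """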
# optionally run the router in full precision to avoid instability during training
# see discussion on pg. 9 here: https://arxiv.org/abs/2101.03961
# setting enabled to False in autocast automatically puts everything in float32
device_type = 'cuda' if torch.cuda.is_available() else 'cpu' # for later use in torch.autocast
ctx = nullcontext() if not self.router_use_full_prec else torch.amp.autocast(device_type=device_type, enabled=False)
with ctx:
B, T, _ = x.size()
num_tokens = B * T
# eq (4) in (https://arxiv.org/abs/1701.06538)
logits = self.w_g(x) # [B, T, n_exp]
if self.use_noisy_top_k:
# optionally add noise into the router
noise = F.softplus(self.w_noise(x))
noise *= torch.randn_like(noise)
logits += noise
# router z loss, computed on logits (before softmax)
# this loss prevents router logits from becoming too large
if self.use_router_z_loss:
z_loss = self.compute_router_z_loss(logits)
MANAGER.add_router_z_loss(z_loss)
# find top k experts for each token
top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1) # [B, T, k]
            # normalize expert probabilities
            # Question: should we normalize over all experts or just the top-k?
            # We normalize over the top-k only, following Shazeer et al
            # (https://arxiv.org/abs/1701.06538), page 4 eq (3)-(5); the alternative
            # (softmax over all logits) is commented out below
router_probs = torch.full_like(logits, float('-inf')) # [B, T, n_exp]
router_probs.scatter_(-1, top_k_indices, top_k_logits)
router_probs = F.softmax(router_probs, dim=-1)
# # normalize all router logits (not just top-k) via softmax
# router_probs = F.softmax(logits, dim=-1)
# compute auxiliary load balancing loss
# this loss encourages equal probability assigned to each expert
# and equal load balancing of tokens assigned to each expert
if self.use_aux_loss:
aux_loss = self.compute_aux_loss(router_probs, top_k_indices)
MANAGER.add_aux_loss(aux_loss)
# compute expert capacity
exp_capacity = self.get_capacity(num_tokens)
# make a multi-hot mask of chosen experts, size [B, T, n_exp]
# entries are 0 if expert not chosen and 1 if expert chosen
exp_mask = F.one_hot(top_k_indices, num_classes=self.n_exp) # [B, T, k, n_exp]
exp_mask = exp_mask.view(num_tokens, self.top_k, self.n_exp) # [B * T, k, n_exp]
exp_mask = exp_mask.permute(1, 0, 2) # [k, B * T, n_exp]
# compute cumulative sum of each token over experts, this stores
# the index of each token within the batch of each expert
# NOTE: cumsum should count all top-1 first, top-2 second, etc.
# so that we prioritize top experts when dropping tokens (this is
# done by putting k dimension first for the reshape operation)
exp_rank = exp_mask.reshape(self.top_k * num_tokens, self.n_exp) # [k * B * T, n_exp]
exp_rank = torch.cumsum(exp_rank, dim=0) - 1 # cumulative sum of expert selections [k * B * T, n_exp]
exp_rank = exp_rank.reshape(self.top_k, num_tokens, self.n_exp) # [k, B * T, n_exp]
# mask out (set to zero) entries that go beyond expert capacity
# compute amount of used capacity by taking a sum over mask
exp_mask *= torch.lt(exp_rank, exp_capacity) # [k, B * T, n_exp]
used_capacity = torch.sum(exp_mask, dim=(0, 1)) # [n_exp]
# mask rank to only include tokens that are selected
# perform a sum so each row only contains index of token
# for the expert that is selected in that row
# result is a matrix that contains the position of each token
# in the batch of its corresponding expert
exp_rank = torch.sum(exp_mask * exp_rank, dim=-1) # [k, B * T]
# mask probabilities to only include selected experts
router_probs = router_probs.view(num_tokens, self.n_exp)[None, :] # [1, B * T, n_exp]
exp_weights = exp_mask * router_probs # [k, B * T, n_exp]
# convert rank into one-hot vectors over the available capacity
# stores the position of each token within the capacity of the selected expert
exp_rank_sc = F.one_hot(exp_rank, num_classes=exp_capacity) # [k, B * T, exp_capacity]
# create a vector that stores, for each token, the weight of selected
# experts at token's position in the capacity of that expert
# size of tensor is [B * T, n_exp, exp_capacity]
cb_weight = torch.sum(exp_weights.unsqueeze(3) * exp_rank_sc.unsqueeze(2), dim=0)
sec_mask = cb_weight.bool() # binary mask of selected experts for each token
return used_capacity, cb_weight, sec_mask
def compute_aux_loss(self, expert_probs: torch.Tensor, indices: torch.Tensor):
"""
Computes Switch Transformer auxiliary loss (https://arxiv.org/abs/2101.03961)
See equations (4)-(6) on page 7
"""
# equation (5): compute ratio of tokens allocated to each expert
# total number of tokens is defined as total tokens in batch * k
# (k = 1) for the Switch Transformer
with torch.no_grad():
one_hot_indices = F.one_hot(indices, num_classes=self.n_exp) # [B, T, k, n_exp]
one_hot_indices = torch.sum(one_hot_indices.float(), dim=2) # [B, T, n_exp] (sum over k dimension)
tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1))
# equation (6): compute ratio of router probability allocated to each expert
prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1))
# equation (4): take a scaled dot product between prob/token allocation vectors
# multiply the result by the number of experts
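        # note: under a perfectly balanced router (probability 1/n_exp and k/n_exp of the
        # tokens per expert) this loss evaluates to top_k, i.e. 1 for the Switch Transformer (k = 1)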
return self.n_exp * torch.sum(prob_per_expert * tokens_per_expert)
def compute_router_z_loss(self, logits: torch.Tensor):
"""
Computes ST-MoE router z loss (https://arxiv.org/abs/2202.08906)
See equation (5) on page 7
"""
# exponentiate logits, sum logits of each expert, take log, and square
# code below is the same as:
# > z_loss = torch.exp(logits)
# > z_loss = torch.sum(z_loss, dim=-1)
# > z_loss = torch.log(z_loss) ** 2.0
        z_loss = torch.logsumexp(logits, dim=-1) ** 2.0 # [B, T]
# sum over all tokens and divide by total number of tokens
return torch.mean(z_loss)
def get_capacity(self, tokens_per_batch):
# expert capacity is given by (tokens_per_batch / num_experts) * capacity_factor
# see eq (3) in Switch Transformer (https://arxiv.org/abs/2101.03961)
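        # e.g. with top_k = 2, capacity_factor = 1.25, 4096 tokens and 8 experts:
        # floor(2 * 1.25 * 4096 / 8) = 1280 capacity slots per expert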
capacity_factor = self.train_capacity if self.training else self.eval_capacity
capacity = math.floor(self.top_k * capacity_factor * tokens_per_batch / self.n_exp)
capacity += capacity % 2 # make sure capacity is an even number
capacity = max(capacity, self.min_capacity) # use min capacity
assert capacity > 0
return int(capacity)
class MLPExperts(nn.Module):
"""
implementation of multiple MLP-based experts that can process input
in batch -- based upon ColossalAI OpenMoE but simple, has optional bias, and
uses a bmm instead of a loop over a mm for each expert to improve efficiency
link: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/moe/experts.py
"""
def __init__(self, config):
        # NOTE: parameters are initialized in GPT._init_weights (MLPExperts branch)
super().__init__()
self.bias = config.bias
self.c_fc = nn.Parameter(torch.empty(config.n_exp, config.n_embd, 4 * config.n_embd))
self.c_proj = nn.Parameter(torch.empty(config.n_exp, 4 * config.n_embd, config.n_embd))
self.fc_bias = nn.Parameter(torch.empty(config.n_exp, 1, 4 * config.n_embd)) if self.bias else None
self.proj_bias = nn.Parameter(torch.empty(config.n_exp, 1, config.n_embd)) if self.bias else None
self.gelu = nn.GELU()
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
x = torch.bmm(x, self.c_fc)
if self.bias:
x += self.fc_bias
x = self.gelu(x)
x = torch.bmm(x, self.c_proj)
if self.bias:
x += self.proj_bias
x = self.dropout(x)
return x
class MOELayer(nn.Module):
def __init__(self, config):
super().__init__()
self.router = Router(config) # (noisy) top k router
self.experts = MLPExperts(config) # group of MLPs (experts)
def forward(self, x: torch.Tensor):
B, T, n_embd = x.size() # track original shape of input
num_tokens = (B * T)
# pass each token through the router
used_capacity, exp_weight, exp_mask = self.router(x)
# flatten out the input
x = x.view(num_tokens, n_embd)
# reshape tokens into batches for each expert
# [n_exp, exp_capacity, B * T] * [B * T, n_embd] -> [n_exp, exp_capacity, n_embd]
exp_batches = exp_mask.permute(1, 2, 0).type_as(x) @ x
# compute expert output
exp_out = self.experts(exp_batches) # [n_exp, exp_capacity, n_embd]
# aggregate expert outputs based on router weights
# eq (2) on page 4 of ST-MoE (https://arxiv.org/abs/2202.08906)
# similar equations are used for other MoE papers
exp_weight = exp_weight.view(num_tokens, -1) # [B * T, n_exp * exp_capacity]
exp_out = exp_out.view(-1, n_embd) # [n_exp * exp_capacity, n_embd]
output = exp_weight @ exp_out # [B * T, n_embd]
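        # note: tokens dropped for exceeding expert capacity have all-zero rows in
        # exp_weight, so their MoE output is zero and they only flow through the
        # residual connection of the surrounding Block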
# resize output before return
return output.view(B, T, n_embd)
class Block(nn.Module):
def __init__(self, config, use_moe=False):
super().__init__()
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
self.attn = CausalSelfAttention(config)
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
if use_moe:
self.mlp = MOELayer(config)
else:
self.mlp = MLP(config)
def forward(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config
if config.n_exp == 1:
# create normal transformer blocks
blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
else:
# create transformer blocks, placing an MoE block every <stride> layers
blocks = []
for i in range(config.n_layer):
                # place an MoE layer on every stride-th block; with i % stride == 0 the first
                # layer (i = 0) is MoE, whereas (i + 1) % stride == 0 would start MoE at
                # layer stride - 1 instead (TODO: decide which placement is preferable)
use_moe = (i % config.stride) == 0
blocks.append(Block(config, use_moe=use_moe))
blocks = nn.ModuleList(blocks)
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
wpe = nn.Embedding(config.block_size, config.n_embd),
drop = nn.Dropout(config.dropout),
h = blocks,
ln_f = LayerNorm(config.n_embd, bias=config.bias),
))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying: the token embedding and the lm_head share the same weight matrix
        self.transformer.wte.weight = self.lm_head.weight
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
# init all weights
        # optionally use the Switch Transformer special init scheme for the experts
# See pg. 10 here: https://arxiv.org/abs/2101.03961
self.apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith('c_proj.weight') or pn.endswith('experts.c_proj'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
# report number of parameters
print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
def get_num_params(self, non_embedding=True):
"""
Return the number of parameters in the model.
For non-embedding count (default), the position embeddings get subtracted.
The token embeddings would too, except due to the parameter sharing these
params are actually used as weights in the final layer, so we include them.
"""
n_params = sum(p.numel() for p in self.parameters())
if non_embedding:
n_params -= self.transformer.wpe.weight.numel()
return n_params
@torch.no_grad()
def _init_weights(self, module):
# optionally use switch transformer-style initialization
# see page 10 for switch init explanation: https://arxiv.org/abs/2101.03961
if isinstance(module, nn.Linear):
if self.config.use_switch_tfm_init:
scale = self.config.switch_tfm_init_scale
# linear layers have flipped dimensions in torch
# size of weights is [out_dim, in_dim]
w_fan_in = module.weight.shape[-1]
w_std = (scale / w_fan_in) ** 0.5
torch.nn.init.trunc_normal_(
module.weight,
mean=0.0,
std=w_std,
a=-2*w_std,
b=2*w_std,
)
else:
# perform standard (normal) initialization of weights
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
# always initialize bias to zero
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, MLPExperts):
            # we have to init the expert weights manually because they are raw
            # nn.Parameters (not nn.Linear modules), so the nn.Linear branch above
            # does not cover them
if self.config.use_switch_tfm_init:
scale = self.config.switch_tfm_init_scale
c_fc_fan_in = module.c_fc.shape[-2]
c_fc_std = (scale / c_fc_fan_in) ** 0.5
torch.nn.init.trunc_normal_(
module.c_fc,
mean=0.0,
std=c_fc_std,
a=-2*c_fc_std,
b=2*c_fc_std,
)
c_proj_fan_in = module.c_proj.shape[-2]
c_proj_std = (scale / c_proj_fan_in) ** 0.5
torch.nn.init.trunc_normal_(
module.c_proj,
mean=0.0,
std=c_proj_std,
a=-2*c_proj_std,
b=2*c_proj_std,
)
else:
# perform standard (normal) initialization of weights
torch.nn.init.normal_(module.c_fc, mean=0.0, std=0.02)
torch.nn.init.normal_(module.c_proj, mean=0.0, std=0.02)
# bias is always initialized to zero
if module.fc_bias is not None:
torch.nn.init.zeros_(module.fc_bias)
torch.nn.init.zeros_(module.proj_bias)
elif isinstance(module, nn.Embedding):
# just use standard initialization scheme for embedding always
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
# def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
# model_dim = self.config.n_embd
# ddp, rank, local_rank, world_size = get_dist_info()
# # Separate out all parameters into groups
# # ========== MOE MODIFICATION START ==========
# # Separate MoE expert parameters (3D) from regular matrix parameters (2D)
# # Muon optimizer only accepts 2D parameters, so we need to filter out 3D (MoE experts) and 1D (bias/norm) params
# matrix_params = []
# moe_params = [] # MoE expert parameters are 3D and need AdamW optimizer
# other_params = [] # 1D parameters (bias, norm weights) also go to AdamW
# for param in self.transformer.h.parameters():
# if param.ndim == 3: # MoE expert parameters: [n_exp, ...]
# moe_params.append(param)
# elif param.ndim == 2: # Regular 2D matrix parameters for Muon
# matrix_params.append(param)
# else: # 1D parameters (bias, norm weights) go to AdamW
# other_params.append(param)
# # ========== MOE MODIFICATION END ==========
# embedding_params = list(self.transformer.wte.parameters())
# lm_head_params = list(self.lm_head.parameters())
# # Create the AdamW optimizer for the embedding, lm_head, and MoE experts
# # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
# dmodel_lr_scale = (model_dim / 768) ** -0.5
# if rank == 0:
# print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
# if moe_params:
# print(f"Found {len(moe_params)} MoE expert parameters (3D) to optimize with AdamW")
# adam_groups = [
# dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
# dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
# ]
# # ========== MOE MODIFICATION START ==========
# # Add MoE expert parameters to AdamW optimizer (use matrix_lr for consistency)
# if moe_params:
# adam_groups.append(dict(params=moe_params, lr=matrix_lr * dmodel_lr_scale))
# # ========== MOE MODIFICATION END ==========
# adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
# AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
# adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# # Create the Muon optimizer for the 2D linear layers only
# muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
# MuonFactory = DistMuon if ddp else Muon
# muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
    # # Combine the two optimizers into one list
# optimizers = [adamw_optimizer, muon_optimizer]
# for opt in optimizers:
# for group in opt.param_groups:
# group["initial_lr"] = group["lr"]
# return optimizers
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# TODO: add expert config
# start with all of the candidate parameters
param_dict = {pn: p for pn, p in self.named_parameters()}
# filter out those that do not require grad
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is at least 2D will be weight decayed, otherwise no.
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
# add an extra check for "bias" string to account for bias terms in MoE layers
decay_params = [p for n, p in param_dict.items() if (p.dim() >= 2 and not n.endswith('bias'))]
nodecay_params = [p for n, p in param_dict.items() if (p.dim() < 2 or n.endswith('bias'))]
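        # note: the 3D expert tensors (MLPExperts.c_fc / c_proj) have dim() >= 2, so they
        # receive weight decay here alongside the regular 2D matmul weights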
optim_groups = [
{'params': decay_params, 'weight_decay': weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
return optimizer
def forward(self, idx, targets=None):
device = idx.device
b, t = idx.size()
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
# add the auxiliary load balancing loss and router z loss to the main loss
if self.config.n_exp > 1 and self.config.use_aux_loss:
loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
MANAGER.reset_aux_loss()
if self.config.n_exp > 1 and self.config.use_router_z_loss:
loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
MANAGER.reset_router_z_loss()
else:
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None
return logits, loss
def crop_block_size(self, block_size):
# model surgery to decrease the block size if necessary
# e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
# but want to use a smaller block size for some smaller, simpler model
assert block_size <= self.config.block_size
self.config.block_size = block_size
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
for block in self.transformer.h:
if hasattr(block.attn, 'bias'):
block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
@classmethod
def from_pretrained(cls, model_type, override_args=None):
        assert 'moe' not in model_type, "Pretrained checkpoints not available for MoE"
assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
override_args = override_args or {} # default to empty dict
        # only dropout can be overridden, see more notes below
assert all(k == 'dropout' for k in override_args)
from transformers import GPT2LMHeadModel
print("loading weights from pretrained gpt: %s" % model_type)
# n_layer, n_head and n_embd are determined from model_type
config_args = {
'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}[model_type]
print("forcing vocab_size=50257, block_size=1024, bias=True")
config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
config_args['bias'] = True # always True for GPT model checkpoints
# we can override the dropout rate, if desired
if 'dropout' in override_args:
print(f"overriding dropout rate to {override_args['dropout']}")
config_args['dropout'] = override_args['dropout']
# create a from-scratch initialized minGPT model
config = GPTConfig(**config_args)
model = GPT(config)
sd = model.state_dict()
sd_keys = sd.keys()
sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
# init a huggingface/transformers model
model_hf = GPT2LMHeadModel.from_pretrained(model_type)
sd_hf = model_hf.state_dict()
# copy while ensuring all of the parameters are aligned and match in names and shapes
sd_keys_hf = sd_hf.keys()
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
# basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
# this means that we have to transpose these weights when we import them
assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
for k in sd_keys_hf:
if any(k.endswith(w) for w in transposed):
# special treatment for the Conv1D weights we need to transpose
assert sd_hf[k].shape[::-1] == sd[k].shape
with torch.no_grad():
sd[k].copy_(sd_hf[k].t())
else:
# vanilla copy over the other parameters
assert sd_hf[k].shape == sd[k].shape
with torch.no_grad():
sd[k].copy_(sd_hf[k])
return model
def estimate_mfu(self, fwdbwd_per_iter, dt):
""" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
# first estimate the number of flops we do per iteration.
# see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
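        # note (MoE): 6*N counts every parameter, including all experts, while only top_k of
        # n_exp experts run per token, so for MoE configs this overestimates the achieved
        # FLOPs and therefore the reported MFU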
N = self.get_num_params()
cfg = self.config
L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
flops_per_token = 6*N + 12*L*H*Q*T
flops_per_fwdbwd = flops_per_token * T
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
# express our flops throughput as ratio of A100 bfloat16 peak flops
flops_achieved = flops_per_iter * (1.0/dt) # per second
flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
mfu = flops_achieved / flops_promised
return mfu
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
return idx
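
if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the training pipeline):
    # build a tiny MoE config, run one forward/backward pass, and sample a few tokens.
    # Assumes nanochat.manager.MANAGER behaves as used in GPT.forward above; the
    # hyperparameters below are arbitrary small values chosen only for a quick check.
    config = GPTConfig(
        block_size=64, vocab_size=256, n_layer=4, n_head=4, n_embd=64,
        n_exp=4, top_k=2, stride=2,
    )
    model = GPT(config)
    idx = torch.randint(0, config.vocab_size, (2, 32))
    targets = torch.randint(0, config.vocab_size, (2, 32))
    logits, loss = model(idx, targets)  # loss = cross-entropy + aux loss + router z-loss
    print(f"logits shape: {tuple(logits.shape)}, loss: {loss.item():.4f}")
    loss.backward()
    model.eval()
    out = model.generate(idx[:, :4], max_new_tokens=8, temperature=1.0, top_k=16)
    print(f"generated shape: {tuple(out.shape)}")  # (2, 12)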