Mirror of https://github.com/karpathy/nanochat.git (synced 2026-03-20 20:03:19 +00:00)
Refactor MoEFeedForward class and remove ExpertFFN implementation
- Removed the ExpertFFN class to streamline the MoEFeedForward implementation.
- Introduced stacked shared-expert weights (shared_w1/shared_w2), so shared experts are plain batched parameters instead of per-expert modules (a standalone sketch of the batched math follows the commit metadata below).
- Updated the forward method to use the shared weights and improved tensor handling for routing statistics.
- Added methods for cached token-index management and a batched shared forward computation, optimizing performance.
- Enhanced state loading to migrate legacy checkpoints into the new shared-weights structure.
This commit is contained in:
parent
25d2573f47
commit
d7c62c2cfe
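For context on the stacked shared-expert math introduced in this diff, here is a minimal standalone sketch (dimensions and tensor values are illustrative, not from the repo); it checks that one batched einsum pair reproduces what the old per-ExpertFFN loop computed:

import torch
import torch.nn.functional as F

# Illustrative sizes only; the real ones come from the model config.
num_shared, model_dim, hidden_dim, tokens = 2, 8, 32, 5
shared_w1 = torch.randn(num_shared, hidden_dim, model_dim)  # like shared_w1 below
shared_w2 = torch.randn(num_shared, model_dim, hidden_dim)  # like shared_w2 below
x_flat = torch.randn(tokens, model_dim)

# Batched path: one einsum pair over all shared experts at once.
hidden = torch.einsum("ehd,td->eth", shared_w1, x_flat)  # (experts, tokens, hidden)
hidden = F.relu(hidden).square()                         # squared-ReLU, as in MLP
out = torch.einsum("eth,edh->etd", hidden, shared_w2)    # (experts, tokens, model)
batched = out.mean(dim=0)                                # average over shared experts

# Old path: Python loop over per-expert Linear weights (c_fc then c_proj).
looped = torch.stack([
    F.relu(x_flat @ w1.T).square() @ w2.T for w1, w2 in zip(shared_w1, shared_w2)
]).mean(dim=0)
assert torch.allclose(batched, looped, rtol=1e-4, atol=1e-4)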
nanochat/gpt.py (148 changed lines)
@@ -130,20 +130,6 @@ class MLP(nn.Module):
         return x


-class ExpertFFN(nn.Module):
-    def __init__(self, input_dim, ffn_mult):
-        super().__init__()
-        hidden_dim = max(1, int(round(ffn_mult * input_dim)))
-        self.c_fc = nn.Linear(input_dim, hidden_dim, bias=False)
-        self.c_proj = nn.Linear(hidden_dim, input_dim, bias=False)
-
-    def forward(self, x):
-        x = self.c_fc(x)
-        x = F.relu(x).square()
-        x = self.c_proj(x)
-        return x
-
-
 class MoEFeedForward(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -158,9 +144,12 @@ class MoEFeedForward(nn.Module):
         self.hidden_dim = max(1, int(round(self.ffn_mult * self.model_dim)))
         self.routed_w1 = nn.Parameter(torch.empty(self.num_routed_experts, self.hidden_dim, self.model_dim))
         self.routed_w2 = nn.Parameter(torch.empty(self.num_routed_experts, self.model_dim, self.hidden_dim))
-        self.shared_experts = nn.ModuleList(
-            [ExpertFFN(self.model_dim, self.ffn_mult) for _ in range(self.num_shared_experts)]
-        )
+        if self.num_shared_experts > 0:
+            self.shared_w1 = nn.Parameter(torch.empty(self.num_shared_experts, self.hidden_dim, self.model_dim))
+            self.shared_w2 = nn.Parameter(torch.empty(self.num_shared_experts, self.model_dim, self.hidden_dim))
+        else:
+            self.register_parameter("shared_w1", None)
+            self.register_parameter("shared_w2", None)
         self.register_buffer("router_bias", torch.zeros(self.num_routed_experts, dtype=torch.float32))
         self.register_buffer("ema_load", torch.zeros(self.num_routed_experts, dtype=torch.float32))
         self.balance_strength = 0.001
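A quick aside on the register_parameter("shared_w1", None) pattern in the else branch above: it keeps the attribute present (so later "is not None" checks work uniformly) without tracking any tensor. A tiny demonstration, assuming nothing beyond stock PyTorch:

import torch.nn as nn

m = nn.Module()
m.register_parameter("shared_w1", None)
print(m.shared_w1)                 # None, but the attribute exists
print(list(m.named_parameters()))  # [] (None entries are skipped)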
@@ -169,8 +158,40 @@ class MoEFeedForward(nn.Module):
         self._last_stats = None
         self.last_entropy = None
         self.last_load = None
+        self._token_idx_capacity = 0
+        self._token_idx_device = None
+        self._token_idx_cache = None
         self._init_experts()

+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        legacy_prefix = prefix + "shared_experts."
+        if any(key.startswith(legacy_prefix) for key in list(state_dict.keys())) and self.shared_w1 is not None:
+            for idx in range(self.num_shared_experts):
+                w1_key = f"{legacy_prefix}{idx}.c_fc.weight"
+                w2_key = f"{legacy_prefix}{idx}.c_proj.weight"
+                if w1_key in state_dict:
+                    self.shared_w1.data[idx].copy_(state_dict.pop(w1_key))
+                if w2_key in state_dict:
+                    self.shared_w2.data[idx].copy_(state_dict.pop(w2_key))
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
     def _balance_router(self, assignments, probs):
         if not self.training:
             self.last_load = None
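The per-expert copy_ in the migration above works without transposes because nn.Linear stores its weight as (out_features, in_features). A minimal shape check (sizes are made up for illustration):

import torch
import torch.nn as nn

model_dim, hidden_dim, num_shared = 8, 32, 2
legacy_fc = nn.Linear(model_dim, hidden_dim, bias=False)    # old ExpertFFN.c_fc
shared_w1 = torch.empty(num_shared, hidden_dim, model_dim)  # new stacked parameter
assert shared_w1[0].shape == legacy_fc.weight.shape         # both (hidden_dim, model_dim)
shared_w1[0].copy_(legacy_fc.weight)                        # shape-exact, no transpose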
@@ -188,24 +209,28 @@ class MoEFeedForward(nn.Module):
         self.router_bias.clamp_(-self.bias_clamp, self.bias_clamp)
         entropy = (-probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
         load_cpu = load.detach().cpu()
+        entropy_cpu = entropy.detach().cpu()
         stats = {
             "tokens": probs.size(0),
-            "load_std": load_cpu.std(unbiased=False).item(),
-            "load_max": load_cpu.max().item(),
-            "load_min": load_cpu.min().item(),
-            "load_active_frac": (load_cpu > 0).float().mean().item(),
-            "load_imbalance": torch.abs(load_cpu - (1.0 / self.num_routed_experts)).mean().item(),
+            "load_std": load_cpu.std(unbiased=False),
+            "load_max": load_cpu.max(),
+            "load_min": load_cpu.min(),
+            "load_active_frac": (load_cpu > 0).float().mean(),
+            "load_imbalance": torch.abs(load_cpu - (1.0 / self.num_routed_experts)).mean(),
+            "entropy": entropy_cpu,
         }
-        self._last_stats = stats
+        self._last_stats = {key: (value.detach() if torch.is_tensor(value) else value) for key, value in stats.items()}
         self.last_load = load_cpu
-        self.last_entropy = entropy.item()
+        self.last_entropy = entropy_cpu

     def forward(self, x):
         B, T, C = x.shape
         x_flat = x.view(B * T, C)
-        router_logits = self.router(x_flat).float()
-        router_logits = router_logits + self.router_bias.to(router_logits.device)
-        probs = torch.softmax(router_logits, dim=-1)
+        router_logits = self.router(x_flat)
+        router_logits = router_logits + self.router_bias.to(router_logits.device, dtype=router_logits.dtype)
+        if x_flat.device.type in {"cuda", "mps"} and router_logits.dtype != torch.bfloat16:
+            router_logits = router_logits.to(torch.bfloat16)
+        probs = torch.softmax(router_logits, dim=-1).to(x_flat.dtype)
         topk_scores, topk_idx = torch.topk(probs, k=self.top_k, dim=-1)
         topk_scores = topk_scores / torch.clamp(topk_scores.sum(dim=-1, keepdim=True), min=1e-9)
         flat_assignments = topk_idx.reshape(-1)
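The shift from inline .item() calls to detached tensors in the stats dict above is presumably about host-device synchronization: .item() blocks until the GPU value is ready, while a detached tensor does not. A minimal sketch of the pattern (to_scalars is a hypothetical stand-in for the get_stats method further down):

import torch

load = torch.rand(8)  # stand-in for per-expert load on the training device
stats = {"load_std": load.std(unbiased=False), "tokens": load.numel()}  # tensors stay lazy

def to_scalars(stats):
    # The sync to Python floats happens here, only when someone asks for logs.
    return {k: (v.item() if torch.is_tensor(v) else v) for k, v in stats.items()}

print(to_scalars(stats))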
@@ -214,19 +239,17 @@ class MoEFeedForward(nn.Module):
         mixture = self._dispatch_batched(flat_assignments, x_flat, topk_scores)

         if self._last_stats is not None:
-            self._last_stats["gate_mean"] = topk_scores.mean().item()
-            self._last_stats["gate_std"] = topk_scores.std(unbiased=False).item()
+            self._last_stats["gate_mean"] = topk_scores.mean().detach().cpu()
+            self._last_stats["gate_std"] = topk_scores.std(unbiased=False).detach().cpu()
             confident = probs.max(dim=-1).values
-            self._last_stats["router_confidence"] = confident.mean().item()
+            self._last_stats["router_confidence"] = confident.mean().detach().cpu()
             self._last_stats["tokens"] = x_flat.size(0)
-            self._last_stats["entropy"] = self.last_entropy if self.last_entropy is not None else 0.0
+            if self.last_entropy is not None:
+                self._last_stats["entropy"] = self.last_entropy

         if self.num_shared_experts > 0:
-            shared_sum = None
-            for shared in self.shared_experts:
-                out = shared(x_flat)
-                shared_sum = out if shared_sum is None else shared_sum + out
-            mixture = mixture + shared_sum / self.num_shared_experts
+            shared = self._shared_forward(x_flat)
+            if shared is not None:
+                mixture = mixture + shared

         return mixture.view(B, T, C)
@@ -236,7 +259,7 @@ class MoEFeedForward(nn.Module):
         if total == 0:
             return torch.zeros_like(x_flat)

-        token_idx = torch.arange(num_tokens, device=x_flat.device).repeat_interleave(self.top_k)
+        token_idx = self._get_token_indices(num_tokens, x_flat.device)
         selected_inputs = x_flat[token_idx]
         expert_ids = flat_assignments
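The _get_token_indices call above replaces a fresh arange(...).repeat_interleave(top_k) allocation per forward pass with a cached, geometrically grown index buffer (defined in the next hunk). A hypothetical standalone version of the same trick:

import torch

class TokenIndexCache:
    # Hypothetical standalone version of MoEFeedForward._get_token_indices.
    def __init__(self, top_k):
        self.top_k = top_k
        self.capacity = 0
        self.device = None
        self.cache = None

    def get(self, num_tokens, device):
        if self.cache is None or self.device != device or self.capacity < num_tokens:
            new_capacity = max(num_tokens, max(1, self.capacity * 2))  # grow geometrically
            self.cache = torch.arange(new_capacity, device=device, dtype=torch.long).repeat_interleave(self.top_k)
            self.capacity = new_capacity
            self.device = device
        return self.cache[: num_tokens * self.top_k]

print(TokenIndexCache(top_k=2).get(num_tokens=3, device="cpu"))  # tensor([0, 0, 1, 1, 2, 2])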
@@ -250,9 +273,43 @@ class MoEFeedForward(nn.Module):
         weights = topk_scores.unsqueeze(-1).to(routed.dtype)
         return torch.sum(routed * weights, dim=1)

+    def _get_token_indices(self, num_tokens, device):
+        required_tokens = num_tokens
+        if (
+            self._token_idx_cache is None
+            or self._token_idx_device != device
+            or self._token_idx_capacity < required_tokens
+        ):
+            new_capacity = max(required_tokens, max(1, self._token_idx_capacity * 2))
+            idx = torch.arange(new_capacity, device=device, dtype=torch.long).repeat_interleave(self.top_k)
+            self._token_idx_cache = idx
+            self._token_idx_capacity = new_capacity
+            self._token_idx_device = device
+        total = required_tokens * self.top_k
+        return self._token_idx_cache[:total]
+
+    def _shared_forward(self, x_flat):
+        if self.num_shared_experts == 0:
+            return None
+        shared_w1 = self.shared_w1
+        shared_w2 = self.shared_w2
+        x_for_shared = x_flat.to(shared_w1.dtype)
+        hidden = torch.einsum("ehd,td->eth", shared_w1, x_for_shared)
+        hidden = F.relu(hidden).square()
+        shared = torch.einsum("eth,edh->etd", hidden, shared_w2)
+        return shared.mean(dim=0).to(x_flat.dtype)
+
     def _init_experts(self):
         self._init_linear(self.routed_w1, fan_in=self.model_dim, fan_out=self.hidden_dim)
         self._init_linear(self.routed_w2, fan_in=self.hidden_dim, fan_out=self.model_dim)
+        if self.shared_w1 is not None:
+            for weight in self.shared_w1:
+                self._init_linear(weight, fan_in=self.model_dim, fan_out=self.hidden_dim)
+        if self.shared_w2 is not None:
+            for weight in self.shared_w2:
+                self._init_linear(weight, fan_in=self.hidden_dim, fan_out=self.model_dim)

     @staticmethod
     def _init_linear(weight, fan_in, fan_out):
@@ -262,11 +319,18 @@ class MoEFeedForward(nn.Module):
     def get_stats(self):
         if self._last_stats is None:
             return None
-        return dict(self._last_stats)
+        stats = {}
+        for key, value in self._last_stats.items():
+            if torch.is_tensor(value):
+                stats[key] = value.item()
+            else:
+                stats[key] = value
+        return stats

     def count_parameters(self):
         total = self.routed_w1.numel() + self.routed_w2.numel()
-        total += sum(p.numel() for expert in self.shared_experts for p in expert.parameters())
+        if self.shared_w1 is not None:
+            total += self.shared_w1.numel() + self.shared_w2.numel()
         return total
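The new count_parameters arithmetic matches the old one: stacking changes the layout, not the total. With illustrative sizes:

# num_shared ExpertFFN modules:  num_shared * (hidden*model + model*hidden) weights
# stacked shared_w1/shared_w2:   the same total, stored as two 3-D tensors
num_shared, model_dim, hidden_dim = 2, 8, 32
per_expert = hidden_dim * model_dim + model_dim * hidden_dim
assert num_shared * per_expert == 2 * num_shared * hidden_dim * model_dim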
@@ -316,8 +380,8 @@ class GPT(nn.Module):
             elif isinstance(block.mlp, MoEFeedForward):
                 with torch.no_grad():
                     block.mlp.routed_w2.zero_()
-                    for expert in block.mlp.shared_experts:
-                        torch.nn.init.zeros_(expert.c_proj.weight)
+                    if block.mlp.shared_w2 is not None:
+                        block.mlp.shared_w2.zero_()
             torch.nn.init.zeros_(block.attn.c_proj.weight)
         # init the rotary embeddings
         head_dim = self.config.n_embd // self.config.n_head