Refactor MoEFeedForward class and remove ExpertFFN implementation

- Removed the ExpertFFN class to streamline the MoEFeedForward implementation.
- Introduced shared weights for experts, enhancing flexibility in expert configurations.
- Updated the forward method to utilize shared weights and improved tensor handling for statistics.
- Added methods for managing token indices and shared forward computations, optimizing performance.
- Enhanced state loading to accommodate new shared weights structure.
This commit is contained in:
William Thurston 2025-11-11 21:16:33 -08:00
parent 25d2573f47
commit d7c62c2cfe

View File

@ -130,20 +130,6 @@ class MLP(nn.Module):
return x
class ExpertFFN(nn.Module):
def __init__(self, input_dim, ffn_mult):
super().__init__()
hidden_dim = max(1, int(round(ffn_mult * input_dim)))
self.c_fc = nn.Linear(input_dim, hidden_dim, bias=False)
self.c_proj = nn.Linear(hidden_dim, input_dim, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
class MoEFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@ -158,9 +144,12 @@ class MoEFeedForward(nn.Module):
self.hidden_dim = max(1, int(round(self.ffn_mult * self.model_dim)))
self.routed_w1 = nn.Parameter(torch.empty(self.num_routed_experts, self.hidden_dim, self.model_dim))
self.routed_w2 = nn.Parameter(torch.empty(self.num_routed_experts, self.model_dim, self.hidden_dim))
self.shared_experts = nn.ModuleList(
[ExpertFFN(self.model_dim, self.ffn_mult) for _ in range(self.num_shared_experts)]
)
if self.num_shared_experts > 0:
self.shared_w1 = nn.Parameter(torch.empty(self.num_shared_experts, self.hidden_dim, self.model_dim))
self.shared_w2 = nn.Parameter(torch.empty(self.num_shared_experts, self.model_dim, self.hidden_dim))
else:
self.register_parameter("shared_w1", None)
self.register_parameter("shared_w2", None)
self.register_buffer("router_bias", torch.zeros(self.num_routed_experts, dtype=torch.float32))
self.register_buffer("ema_load", torch.zeros(self.num_routed_experts, dtype=torch.float32))
self.balance_strength = 0.001
@ -169,8 +158,40 @@ class MoEFeedForward(nn.Module):
self._last_stats = None
self.last_entropy = None
self.last_load = None
self._token_idx_capacity = 0
self._token_idx_device = None
self._token_idx_cache = None
self._init_experts()
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
legacy_prefix = prefix + "shared_experts."
if any(key.startswith(legacy_prefix) for key in list(state_dict.keys())) and self.shared_w1 is not None:
for idx in range(self.num_shared_experts):
w1_key = f"{legacy_prefix}{idx}.c_fc.weight"
w2_key = f"{legacy_prefix}{idx}.c_proj.weight"
if w1_key in state_dict:
self.shared_w1.data[idx].copy_(state_dict.pop(w1_key))
if w2_key in state_dict:
self.shared_w2.data[idx].copy_(state_dict.pop(w2_key))
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def _balance_router(self, assignments, probs):
if not self.training:
self.last_load = None
@ -188,24 +209,28 @@ class MoEFeedForward(nn.Module):
self.router_bias.clamp_(-self.bias_clamp, self.bias_clamp)
entropy = (-probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
load_cpu = load.detach().cpu()
entropy_cpu = entropy.detach().cpu()
stats = {
"tokens": probs.size(0),
"load_std": load_cpu.std(unbiased=False).item(),
"load_max": load_cpu.max().item(),
"load_min": load_cpu.min().item(),
"load_active_frac": (load_cpu > 0).float().mean().item(),
"load_imbalance": torch.abs(load_cpu - (1.0 / self.num_routed_experts)).mean().item(),
"load_std": load_cpu.std(unbiased=False),
"load_max": load_cpu.max(),
"load_min": load_cpu.min(),
"load_active_frac": (load_cpu > 0).float().mean(),
"load_imbalance": torch.abs(load_cpu - (1.0 / self.num_routed_experts)).mean(),
"entropy": entropy_cpu,
}
self._last_stats = stats
self._last_stats = {key: (value.detach() if torch.is_tensor(value) else value) for key, value in stats.items()}
self.last_load = load_cpu
self.last_entropy = entropy.item()
self.last_entropy = entropy_cpu
def forward(self, x):
B, T, C = x.shape
x_flat = x.view(B * T, C)
router_logits = self.router(x_flat).float()
router_logits = router_logits + self.router_bias.to(router_logits.device)
probs = torch.softmax(router_logits, dim=-1)
router_logits = self.router(x_flat)
router_logits = router_logits + self.router_bias.to(router_logits.device, dtype=router_logits.dtype)
if x_flat.device.type in {"cuda", "mps"} and router_logits.dtype != torch.bfloat16:
router_logits = router_logits.to(torch.bfloat16)
probs = torch.softmax(router_logits, dim=-1).to(x_flat.dtype)
topk_scores, topk_idx = torch.topk(probs, k=self.top_k, dim=-1)
topk_scores = topk_scores / torch.clamp(topk_scores.sum(dim=-1, keepdim=True), min=1e-9)
flat_assignments = topk_idx.reshape(-1)
@ -214,19 +239,17 @@ class MoEFeedForward(nn.Module):
mixture = self._dispatch_batched(flat_assignments, x_flat, topk_scores)
if self._last_stats is not None:
self._last_stats["gate_mean"] = topk_scores.mean().item()
self._last_stats["gate_std"] = topk_scores.std(unbiased=False).item()
self._last_stats["gate_mean"] = topk_scores.mean().detach().cpu()
self._last_stats["gate_std"] = topk_scores.std(unbiased=False).detach().cpu()
confident = probs.max(dim=-1).values
self._last_stats["router_confidence"] = confident.mean().item()
self._last_stats["router_confidence"] = confident.mean().detach().cpu()
self._last_stats["tokens"] = x_flat.size(0)
self._last_stats["entropy"] = self.last_entropy if self.last_entropy is not None else 0.0
if self.last_entropy is not None:
self._last_stats["entropy"] = self.last_entropy
if self.num_shared_experts > 0:
shared_sum = None
for shared in self.shared_experts:
out = shared(x_flat)
shared_sum = out if shared_sum is None else shared_sum + out
mixture = mixture + shared_sum / self.num_shared_experts
shared = self._shared_forward(x_flat)
if shared is not None:
mixture = mixture + shared
return mixture.view(B, T, C)
@ -236,7 +259,7 @@ class MoEFeedForward(nn.Module):
if total == 0:
return torch.zeros_like(x_flat)
token_idx = torch.arange(num_tokens, device=x_flat.device).repeat_interleave(self.top_k)
token_idx = self._get_token_indices(num_tokens, x_flat.device)
selected_inputs = x_flat[token_idx]
expert_ids = flat_assignments
@ -250,9 +273,43 @@ class MoEFeedForward(nn.Module):
weights = topk_scores.unsqueeze(-1).to(routed.dtype)
return torch.sum(routed * weights, dim=1)
def _get_token_indices(self, num_tokens, device):
required_tokens = num_tokens
if (
self._token_idx_cache is None
or self._token_idx_device != device
or self._token_idx_capacity < required_tokens
):
new_capacity = max(required_tokens, max(1, self._token_idx_capacity * 2))
idx = torch.arange(new_capacity, device=device, dtype=torch.long).repeat_interleave(self.top_k)
self._token_idx_cache = idx
self._token_idx_capacity = new_capacity
self._token_idx_device = device
total = required_tokens * self.top_k
return self._token_idx_cache[:total]
def _shared_forward(self, x_flat):
if self.num_shared_experts == 0:
return None
shared_w1 = self.shared_w1
shared_w2 = self.shared_w2
x_for_shared = x_flat.to(shared_w1.dtype)
hidden = torch.einsum("ehd,td->eth", shared_w1, x_for_shared)
hidden = F.relu(hidden).square()
shared = torch.einsum("eth,edh->etd", hidden, shared_w2)
return shared.mean(dim=0).to(x_flat.dtype)
def _init_experts(self):
self._init_linear(self.routed_w1, fan_in=self.model_dim, fan_out=self.hidden_dim)
self._init_linear(self.routed_w2, fan_in=self.hidden_dim, fan_out=self.model_dim)
if self.shared_w1 is not None:
for weight in self.shared_w1:
self._init_linear(weight, fan_in=self.model_dim, fan_out=self.hidden_dim)
if self.shared_w2 is not None:
for weight in self.shared_w2:
self._init_linear(weight, fan_in=self.hidden_dim, fan_out=self.model_dim)
@staticmethod
def _init_linear(weight, fan_in, fan_out):
@ -262,11 +319,18 @@ class MoEFeedForward(nn.Module):
def get_stats(self):
if self._last_stats is None:
return None
return dict(self._last_stats)
stats = {}
for key, value in self._last_stats.items():
if torch.is_tensor(value):
stats[key] = value.item()
else:
stats[key] = value
return stats
def count_parameters(self):
total = self.routed_w1.numel() + self.routed_w2.numel()
total += sum(p.numel() for expert in self.shared_experts for p in expert.parameters())
if self.shared_w1 is not None:
total += self.shared_w1.numel() + self.shared_w2.numel()
return total
@ -316,8 +380,8 @@ class GPT(nn.Module):
elif isinstance(block.mlp, MoEFeedForward):
with torch.no_grad():
block.mlp.routed_w2.zero_()
for expert in block.mlp.shared_experts:
torch.nn.init.zeros_(expert.c_proj.weight)
if block.mlp.shared_w2 is not None:
block.mlp.shared_w2.zero_()
torch.nn.init.zeros_(block.attn.c_proj.weight)
# init the rotary embeddings
head_dim = self.config.n_embd // self.config.n_head