diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 58c1596..1f172b3 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -130,20 +130,6 @@ class MLP(nn.Module): return x -class ExpertFFN(nn.Module): - def __init__(self, input_dim, ffn_mult): - super().__init__() - hidden_dim = max(1, int(round(ffn_mult * input_dim))) - self.c_fc = nn.Linear(input_dim, hidden_dim, bias=False) - self.c_proj = nn.Linear(hidden_dim, input_dim, bias=False) - - def forward(self, x): - x = self.c_fc(x) - x = F.relu(x).square() - x = self.c_proj(x) - return x - - class MoEFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -158,9 +144,12 @@ class MoEFeedForward(nn.Module): self.hidden_dim = max(1, int(round(self.ffn_mult * self.model_dim))) self.routed_w1 = nn.Parameter(torch.empty(self.num_routed_experts, self.hidden_dim, self.model_dim)) self.routed_w2 = nn.Parameter(torch.empty(self.num_routed_experts, self.model_dim, self.hidden_dim)) - self.shared_experts = nn.ModuleList( - [ExpertFFN(self.model_dim, self.ffn_mult) for _ in range(self.num_shared_experts)] - ) + if self.num_shared_experts > 0: + self.shared_w1 = nn.Parameter(torch.empty(self.num_shared_experts, self.hidden_dim, self.model_dim)) + self.shared_w2 = nn.Parameter(torch.empty(self.num_shared_experts, self.model_dim, self.hidden_dim)) + else: + self.register_parameter("shared_w1", None) + self.register_parameter("shared_w2", None) self.register_buffer("router_bias", torch.zeros(self.num_routed_experts, dtype=torch.float32)) self.register_buffer("ema_load", torch.zeros(self.num_routed_experts, dtype=torch.float32)) self.balance_strength = 0.001 @@ -169,8 +158,40 @@ class MoEFeedForward(nn.Module): self._last_stats = None self.last_entropy = None self.last_load = None + self._token_idx_capacity = 0 + self._token_idx_device = None + self._token_idx_cache = None self._init_experts() + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + legacy_prefix = prefix + "shared_experts." + if any(key.startswith(legacy_prefix) for key in list(state_dict.keys())) and self.shared_w1 is not None: + for idx in range(self.num_shared_experts): + w1_key = f"{legacy_prefix}{idx}.c_fc.weight" + w2_key = f"{legacy_prefix}{idx}.c_proj.weight" + if w1_key in state_dict: + self.shared_w1.data[idx].copy_(state_dict.pop(w1_key)) + if w2_key in state_dict: + self.shared_w2.data[idx].copy_(state_dict.pop(w2_key)) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + def _balance_router(self, assignments, probs): if not self.training: self.last_load = None @@ -188,24 +209,28 @@ class MoEFeedForward(nn.Module): self.router_bias.clamp_(-self.bias_clamp, self.bias_clamp) entropy = (-probs * torch.log(probs + 1e-9)).sum(dim=-1).mean() load_cpu = load.detach().cpu() + entropy_cpu = entropy.detach().cpu() stats = { "tokens": probs.size(0), - "load_std": load_cpu.std(unbiased=False).item(), - "load_max": load_cpu.max().item(), - "load_min": load_cpu.min().item(), - "load_active_frac": (load_cpu > 0).float().mean().item(), - "load_imbalance": torch.abs(load_cpu - (1.0 / self.num_routed_experts)).mean().item(), + "load_std": load_cpu.std(unbiased=False), + "load_max": load_cpu.max(), + "load_min": load_cpu.min(), + "load_active_frac": (load_cpu > 0).float().mean(), + "load_imbalance": torch.abs(load_cpu - (1.0 / self.num_routed_experts)).mean(), + "entropy": entropy_cpu, } - self._last_stats = stats + self._last_stats = {key: (value.detach() if torch.is_tensor(value) else value) for key, value in stats.items()} self.last_load = load_cpu - self.last_entropy = entropy.item() + self.last_entropy = entropy_cpu def forward(self, x): B, T, C = x.shape x_flat = x.view(B * T, C) - router_logits = self.router(x_flat).float() - router_logits = router_logits + self.router_bias.to(router_logits.device) - probs = torch.softmax(router_logits, dim=-1) + router_logits = self.router(x_flat) + router_logits = router_logits + self.router_bias.to(router_logits.device, dtype=router_logits.dtype) + if x_flat.device.type in {"cuda", "mps"} and router_logits.dtype != torch.bfloat16: + router_logits = router_logits.to(torch.bfloat16) + probs = torch.softmax(router_logits, dim=-1).to(x_flat.dtype) topk_scores, topk_idx = torch.topk(probs, k=self.top_k, dim=-1) topk_scores = topk_scores / torch.clamp(topk_scores.sum(dim=-1, keepdim=True), min=1e-9) flat_assignments = topk_idx.reshape(-1) @@ -214,19 +239,17 @@ class MoEFeedForward(nn.Module): mixture = self._dispatch_batched(flat_assignments, x_flat, topk_scores) if self._last_stats is not None: - self._last_stats["gate_mean"] = topk_scores.mean().item() - self._last_stats["gate_std"] = topk_scores.std(unbiased=False).item() + self._last_stats["gate_mean"] = topk_scores.mean().detach().cpu() + self._last_stats["gate_std"] = topk_scores.std(unbiased=False).detach().cpu() confident = probs.max(dim=-1).values - self._last_stats["router_confidence"] = confident.mean().item() + self._last_stats["router_confidence"] = confident.mean().detach().cpu() self._last_stats["tokens"] = x_flat.size(0) - self._last_stats["entropy"] = self.last_entropy if self.last_entropy is not None else 0.0 + if self.last_entropy is not None: + self._last_stats["entropy"] = self.last_entropy - if self.num_shared_experts > 0: - shared_sum = None - for shared in self.shared_experts: - out = shared(x_flat) - shared_sum = out if shared_sum is None else shared_sum + out - mixture = mixture + shared_sum / self.num_shared_experts + shared = self._shared_forward(x_flat) + if shared is not None: + mixture = mixture + shared return mixture.view(B, T, C) @@ -236,7 +259,7 @@ class MoEFeedForward(nn.Module): if total == 0: return torch.zeros_like(x_flat) - token_idx = torch.arange(num_tokens, device=x_flat.device).repeat_interleave(self.top_k) + token_idx = self._get_token_indices(num_tokens, x_flat.device) selected_inputs = x_flat[token_idx] expert_ids = flat_assignments @@ -250,9 +273,43 @@ class MoEFeedForward(nn.Module): weights = topk_scores.unsqueeze(-1).to(routed.dtype) return torch.sum(routed * weights, dim=1) + def _get_token_indices(self, num_tokens, device): + required_tokens = num_tokens + if ( + self._token_idx_cache is None + or self._token_idx_device != device + or self._token_idx_capacity < required_tokens + ): + new_capacity = max(required_tokens, max(1, self._token_idx_capacity * 2)) + idx = torch.arange(new_capacity, device=device, dtype=torch.long).repeat_interleave(self.top_k) + self._token_idx_cache = idx + self._token_idx_capacity = new_capacity + self._token_idx_device = device + total = required_tokens * self.top_k + return self._token_idx_cache[:total] + + + + def _shared_forward(self, x_flat): + if self.num_shared_experts == 0: + return None + shared_w1 = self.shared_w1 + shared_w2 = self.shared_w2 + x_for_shared = x_flat.to(shared_w1.dtype) + hidden = torch.einsum("ehd,td->eth", shared_w1, x_for_shared) + hidden = F.relu(hidden).square() + shared = torch.einsum("eth,edh->etd", hidden, shared_w2) + return shared.mean(dim=0).to(x_flat.dtype) + def _init_experts(self): self._init_linear(self.routed_w1, fan_in=self.model_dim, fan_out=self.hidden_dim) self._init_linear(self.routed_w2, fan_in=self.hidden_dim, fan_out=self.model_dim) + if self.shared_w1 is not None: + for weight in self.shared_w1: + self._init_linear(weight, fan_in=self.model_dim, fan_out=self.hidden_dim) + if self.shared_w2 is not None: + for weight in self.shared_w2: + self._init_linear(weight, fan_in=self.hidden_dim, fan_out=self.model_dim) @staticmethod def _init_linear(weight, fan_in, fan_out): @@ -262,11 +319,18 @@ class MoEFeedForward(nn.Module): def get_stats(self): if self._last_stats is None: return None - return dict(self._last_stats) + stats = {} + for key, value in self._last_stats.items(): + if torch.is_tensor(value): + stats[key] = value.item() + else: + stats[key] = value + return stats def count_parameters(self): total = self.routed_w1.numel() + self.routed_w2.numel() - total += sum(p.numel() for expert in self.shared_experts for p in expert.parameters()) + if self.shared_w1 is not None: + total += self.shared_w1.numel() + self.shared_w2.numel() return total @@ -316,8 +380,8 @@ class GPT(nn.Module): elif isinstance(block.mlp, MoEFeedForward): with torch.no_grad(): block.mlp.routed_w2.zero_() - for expert in block.mlp.shared_experts: - torch.nn.init.zeros_(expert.c_proj.weight) + if block.mlp.shared_w2 is not None: + block.mlp.shared_w2.zero_() torch.nn.init.zeros_(block.attn.c_proj.weight) # init the rotary embeddings head_dim = self.config.n_embd // self.config.n_head