diff --git a/.gitignore b/.gitignore
index c3befa6..49f0686 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,4 +18,8 @@ hf-export/**/__pycache__/
 agent.md
 lm_eval_sample_*/
 *.pt
-hf-export/**/*
\ No newline at end of file
+hf-export/**/*
+lm_eval_pretrain/
+benchmark.md
+d*.png
+loss*.png
\ No newline at end of file
diff --git a/lm_eval.md b/lm_eval.md
index 2e417af..57e44b0 100644
--- a/lm_eval.md
+++ b/lm_eval.md
@@ -14,7 +14,7 @@ source .venv/bin/activate
 ```bash
 # export latest base checkpoint to hf-export/moe_std (gpt2 tokenizer)
 uv run python -m nanochat.to_hf --source base --model-tag d20 --step 49000 --output hf-export/moe_std --tokenizer gpt2
-uv run python -m nanochat.to_hf --source base --model-tag d00 --output hf-export/moe_legacy --tokenizer gpt2
+uv run python -m nanochat.to_hf --source base --model-tag d24_e32_min_lr1e-05_max_lr0.0001 --output hf-export/d24_e32_min_lr1e-05_max_lr0.0001 --tokenizer cache
 # export latest SFT checkpoint (chat model, rustbpe tokenizer)
 uv run python -m nanochat.to_hf --source sft --output hf-export/moe_sft --tokenizer cache
 ```
@@ -45,11 +45,11 @@ HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
   --output_path lm_eval_sample_commonsense > sft_lr8_commonsense.log 2>&1
 
 HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
-  --model_args pretrained=hf-export/moe_sft_lr0.9,trust_remote_code=True,tokenizer=hf-export/moe_sft_lr0.9,max_length=1024 \
+  --model_args pretrained=/thullms/limh23/hf-export/{tag},trust_remote_code=True,tokenizer=/thullms/limh23/hf-export/{tag},max_length=1024 \
   --tasks hellaswag,boolq,piqa,winogrande,arc_easy,arc_challenge,mmlu \
   --batch_size 1 \
   --log_samples \
-  --output_path lm_eval_sample_commonsense > moe_sft_lr0.9_all.log 2>&1
+  --output_path lm_eval_pretrain > benchmark_pretrain/moe_{tag}_all.log 2>&1
 
 # arc_easy,arc_challenge,mmlu
 HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index e27fc4a..ee9e77e 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -585,7 +585,7 @@ class GPT(nn.Module):
 
         return optimizer
 
-    def forward(self, idx, targets=None):
+    def forward(self, idx, targets=None, return_full_logits: bool = False):
         device = idx.device
         b, t = idx.size()
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
@@ -599,22 +599,32 @@ class GPT(nn.Module):
             x = block(x)
         x = self.transformer.ln_f(x)
 
-        if targets is not None:
-            # if we are given some desired targets also calculate the loss
+        if targets is not None or return_full_logits:
+            # full logits are needed for eval tooling; loss is optional
             logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+            if targets is not None:
+                # if we are given some desired targets also calculate the loss
+                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
 
-            # add the auxiliary load balancing loss and router z loss to the main loss
-            if self.config.n_exp > 1 and self.config.use_aux_loss:
-                loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
-                MANAGER.reset_aux_loss()
-            if self.config.n_exp > 1 and self.config.use_router_z_loss:
-                loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
-                MANAGER.reset_router_z_loss()
+                # add the auxiliary load balancing loss and router z loss to the main loss
+                if self.config.n_exp > 1 and self.config.use_aux_loss:
+                    loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
+                    MANAGER.reset_aux_loss()
+                if self.config.n_exp > 1 and self.config.use_router_z_loss:
+                    loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
+                    MANAGER.reset_router_z_loss()
+            else:
+                loss = None
+                if self.config.n_exp > 1:
+                    MANAGER.reset_aux_loss()
+                    MANAGER.reset_router_z_loss()
         else:
             # inference-time mini-optimization: only forward the lm_head on the very last position
             logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
             loss = None
+            if self.config.n_exp > 1:
+                MANAGER.reset_aux_loss()
+                MANAGER.reset_router_z_loss()
 
         return logits, loss
 
diff --git a/nanochat/to_hf.py b/nanochat/to_hf.py
index c8704f9..58e2b1c 100644
--- a/nanochat/to_hf.py
+++ b/nanochat/to_hf.py
@@ -64,7 +64,7 @@ HF_CONFIG_FIELDS = {
     "dropout",
     "bias",
     "n_exp",
-    "top_k",
+    "router_top_k",
     "use_aux_loss",
     "use_router_z_loss",
     "use_noisy_top_k",
@@ -320,7 +320,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         dropout: float = 0.0,
         bias: bool = False,
         n_exp: int = 8,
-        top_k: int = 2,
+        router_top_k: int = 2,
         use_aux_loss: bool = True,
         use_router_z_loss: bool = True,
         use_noisy_top_k: bool = False,
@@ -335,6 +335,10 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         router_use_full_prec: bool = True,
         **kwargs,
     ):
+        if "top_k" in kwargs and "router_top_k" not in kwargs:
+            kwargs["router_top_k"] = kwargs.pop("top_k")
+        if "router_top_k" in kwargs:
+            router_top_k = kwargs.pop("router_top_k")
         kwargs.setdefault("tie_word_embeddings", False)
         super().__init__(**kwargs)
         self.block_size = block_size
@@ -345,7 +349,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         self.dropout = dropout
         self.bias = bias
         self.n_exp = n_exp
-        self.top_k = top_k
+        self.router_top_k = router_top_k
         self.use_aux_loss = use_aux_loss
         self.use_router_z_loss = use_router_z_loss
         self.use_noisy_top_k = use_noisy_top_k
@@ -382,7 +386,7 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
             dropout=config.dropout,
             bias=config.bias,
             n_exp=config.n_exp,
-            top_k=config.top_k,
+            top_k=config.router_top_k,
             use_aux_loss=config.use_aux_loss,
             use_router_z_loss=config.use_router_z_loss,
             use_noisy_top_k=config.use_noisy_top_k,
@@ -421,10 +425,10 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
         if input_ids is None:
             raise ValueError("input_ids must be provided")
         if labels is None:
-            logits, _ = self.model(input_ids)
+            logits, _ = self.model(input_ids, return_full_logits=True)
             loss = None
         else:
-            logits, loss = self.model(input_ids, targets=labels)
+            logits, loss = self.model(input_ids, targets=labels, return_full_logits=True)
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
@@ -459,7 +463,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         dropout=0.0,
         bias=False,
         n_exp=8,
-        top_k=2,
+        router_top_k=2,
         use_aux_loss=True,
         use_router_z_loss=True,
         use_noisy_top_k=False,
@@ -474,6 +478,10 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         router_use_full_prec=True,
         **kwargs,
     ):
+        if "top_k" in kwargs and "router_top_k" not in kwargs:
+            kwargs["router_top_k"] = kwargs.pop("top_k")
+        if "router_top_k" in kwargs:
+            router_top_k = kwargs.pop("router_top_k")
         kwargs.setdefault("tie_word_embeddings", False)
         super().__init__(**kwargs)
         self.block_size = block_size
@@ -484,7 +492,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         self.dropout = dropout
         self.bias = bias
         self.n_exp = n_exp
-        self.top_k = top_k
+        self.router_top_k = router_top_k
         self.use_aux_loss = use_aux_loss
         self.use_router_z_loss = use_router_z_loss
         self.use_noisy_top_k = use_noisy_top_k
@@ -532,7 +540,7 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
             dropout=config.dropout,
             bias=config.bias,
             n_exp=config.n_exp,
-            top_k=config.top_k,
+            top_k=config.router_top_k,
             use_aux_loss=config.use_aux_loss,
             use_router_z_loss=config.use_router_z_loss,
             use_noisy_top_k=config.use_noisy_top_k,
@@ -563,7 +571,7 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
     def forward(self, input_ids=None, attention_mask=None, labels=None, past_key_values=None, **_):
         if input_ids is None:
             raise ValueError("input_ids must be provided")
-        logits, loss = self.model(input_ids, targets=labels)
+        logits, loss = self.model(input_ids, targets=labels, return_full_logits=True)
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
@@ -987,7 +995,7 @@ class GPT(nn.Module):
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
-    def forward(self, idx, targets=None):
+    def forward(self, idx, targets=None, return_full_logits=False):
         B, T = idx.size()
         assert T <= self.config.block_size
         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
@@ -996,18 +1004,27 @@ class GPT(nn.Module):
         for block in self.transformer.h:
             x = block(x)
         x = self.transformer.ln_f(x)
-        if targets is not None:
+        if targets is not None or return_full_logits:
             logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-            if self.config.n_exp > 1 and self.config.use_aux_loss:
-                loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
-                MANAGER.reset_aux_loss()
-            if self.config.n_exp > 1 and self.config.use_router_z_loss:
-                loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
-                MANAGER.reset_router_z_loss()
+            if targets is not None:
+                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+                if self.config.n_exp > 1 and self.config.use_aux_loss:
+                    loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
+                    MANAGER.reset_aux_loss()
+                if self.config.n_exp > 1 and self.config.use_router_z_loss:
+                    loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
+                    MANAGER.reset_router_z_loss()
+            else:
+                loss = None
+                if self.config.n_exp > 1:
+                    MANAGER.reset_aux_loss()
+                    MANAGER.reset_router_z_loss()
         else:
             logits = self.lm_head(x[:, [-1], :])
             loss = None
+            if self.config.n_exp > 1:
+                MANAGER.reset_aux_loss()
+                MANAGER.reset_router_z_loss()
         return logits, loss
 """
 
@@ -1033,6 +1050,10 @@ def export_to_hf(
     tokenizer = load_export_tokenizer(tokenizer_mode)
    model, meta, cfg_kwargs = load_moe_checkpoint(source, model_tag, step, device, tokenizer)
     hf_kwargs = {k: v for k, v in cfg_kwargs.items() if k in HF_CONFIG_FIELDS}
+    if "router_top_k" in cfg_kwargs:
+        hf_kwargs["router_top_k"] = cfg_kwargs["router_top_k"]
+    elif "top_k" in cfg_kwargs:
+        hf_kwargs["router_top_k"] = cfg_kwargs["top_k"]
     hf_config = NanoChatMoEHFConfig(**hf_kwargs)
     hf_model = NanoChatMoEHFForCausalLM(hf_config)
     hf_model.model.load_state_dict(model.state_dict(), strict=True)
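A note on the `return_full_logits` plumbing above: lm-eval's HF loglikelihood scoring needs logits at every sequence position, but the original `GPT.forward` only materializes the last position when no `targets` are given. The new flag forces the full `lm_head` projection without requiring labels. A minimal sketch of the resulting contract, using a toy stand-in model rather than the real `GPT` (the `TinyLM` class here is hypothetical, for shape illustration only):

```python
import torch
import torch.nn as nn

# Toy stand-in illustrating the forward() contract from the diff; not the real GPT.
class TinyLM(nn.Module):
    def __init__(self, vocab_size=16, n_embd=8):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None, return_full_logits=False):
        x = self.wte(idx)  # (B, T, n_embd)
        if targets is not None or return_full_logits:
            # full logits at every position, as eval tooling needs
            logits = self.lm_head(x)  # (B, T, vocab)
            loss = None
            if targets is not None:
                loss = nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # generation fast path: only the last position
            logits = self.lm_head(x[:, [-1], :])  # (B, 1, vocab)
            loss = None
        return logits, loss

m = TinyLM()
idx = torch.randint(0, 16, (2, 5))
assert m(idx)[0].shape == (2, 1, 16)                           # last position only
assert m(idx, return_full_logits=True)[0].shape == (2, 5, 16)  # full sequence
```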
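The extra `MANAGER.reset_*` calls on the no-loss paths guard against a leak: the router layers accumulate their aux/z losses into a global manager as a side effect of every forward, so a forward that never calls `aggregate_*` must still clear the buffers, or the stale terms would be folded into the next training step's loss. A small sketch of that accumulate/drain pattern (this `LossManager` is a hypothetical stand-in, not the repo's actual `MANAGER`):

```python
# Hypothetical accumulator mirroring how the diff uses MANAGER: router layers
# append per-layer losses during forward; every code path must either drain
# (aggregate) or reset the buffer, or terms leak across forward calls.
class LossManager:
    def __init__(self):
        self.aux_losses = []

    def add_aux_loss(self, value):
        self.aux_losses.append(value)

    def aggregate_aux_loss(self):
        return sum(self.aux_losses)

    def reset_aux_loss(self):
        self.aux_losses.clear()

MANAGER = LossManager()
MANAGER.add_aux_loss(0.1)                   # a router contributes during a training forward
total = 1.0 + MANAGER.aggregate_aux_loss()  # main loss + aux term
MANAGER.reset_aux_loss()                    # drained after being folded into the loss
MANAGER.add_aux_loss(0.2)                   # an eval forward also accumulates...
MANAGER.reset_aux_loss()                    # ...so the no-loss path must reset, or 0.2 leaks
assert MANAGER.aggregate_aux_loss() == 0
```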
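Finally, the `top_k` → `router_top_k` rename ships with a back-compat shim so configs saved before the rename still load; presumably the rename also avoids colliding with the `top_k` sampling default that `transformers.PretrainedConfig` manages for generation. The aliasing reduces to the logic below (standalone sketch; `resolve_router_top_k` is a hypothetical helper, where the real code inlines the same steps in `NanoChatMoEHFConfig.__init__` and `export_to_hf` as shown in the hunks above):

```python
def resolve_router_top_k(cfg_kwargs: dict, default: int = 2) -> int:
    kwargs = dict(cfg_kwargs)  # work on a copy; don't mutate the caller's dict
    if "top_k" in kwargs and "router_top_k" not in kwargs:
        kwargs["router_top_k"] = kwargs.pop("top_k")  # legacy spelling promoted
    return kwargs.pop("router_top_k", default)

assert resolve_router_top_k({"top_k": 2}) == 2                      # old checkpoint config
assert resolve_router_top_k({"router_top_k": 4}) == 4               # new config
assert resolve_router_top_k({"router_top_k": 4, "top_k": 50}) == 4  # new spelling wins
assert resolve_router_top_k({}) == 2                                # falls back to default
```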