debug to_hf

Muheng 2026-01-15 12:34:27 +00:00
parent 1dc944f734
commit 9b05e7c625
4 changed files with 69 additions and 34 deletions

.gitignore

@@ -18,4 +18,8 @@ hf-export/**/__pycache__/
 agent.md
 lm_eval_sample_*/
 *.pt
 hf-export/**/*
+lm_eval_pretrain/
+benchmark.md
+d*.png
+loss*.png


@@ -14,7 +14,7 @@ source .venv/bin/activate
 ```bash
 # export latest base checkpoint to hf-export/moe_std (gpt2 tokenizer)
 uv run python -m nanochat.to_hf --source base --model-tag d20 --step 49000 --output hf-export/moe_std --tokenizer gpt2
-uv run python -m nanochat.to_hf --source base --model-tag d00 --output hf-export/moe_legacy --tokenizer gpt2
+uv run python -m nanochat.to_hf --source base --model-tag d24_e32_min_lr1e-05_max_lr0.0001 --output hf-export/d24_e32_min_lr1e-05_max_lr0.0001 --tokenizer cache
 # export latest SFT checkpoint (chat model, rustbpe tokenizer)
 uv run python -m nanochat.to_hf --source sft --output hf-export/moe_sft --tokenizer cache
 ```
@@ -45,11 +45,11 @@ HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
   --output_path lm_eval_sample_commonsense > sft_lr8_commonsense.log 2>&1
 HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
-  --model_args pretrained=hf-export/moe_sft_lr0.9,trust_remote_code=True,tokenizer=hf-export/moe_sft_lr0.9,max_length=1024 \
+  --model_args pretrained=/thullms/limh23/hf-export/{tag},trust_remote_code=True,tokenizer=/thullms/limh23/hf-export/{tag},max_length=1024 \
   --tasks hellaswag,boolq,piqa,winogrande,arc_easy,arc_challenge,mmlu \
   --batch_size 1 \
   --log_samples \
-  --output_path lm_eval_sample_commonsense > moe_sft_lr0.9_all.log 2>&1
+  --output_path lm_eval_pretrain > benchmark_pretrain/moe_{tag}_all.log 2>&1
 # arc_easy,arc_challenge,mmlu
 HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \

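Note: the `--model_args pretrained=...,trust_remote_code=True` invocations above exercise lm-eval's loglikelihood scoring, which needs logits at every sequence position, not just the last one; that is what the `return_full_logits` plumbing in the diffs below exists for. A minimal sketch of that scoring path (the export path and prompt are illustrative; only standard `transformers` calls are used):

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "hf-export/moe_sft"  # illustrative; any export directory above works
tok = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).eval()

def loglikelihood(context: str, continuation: str) -> float:
    """Sum of log-probs of `continuation` given `context` (what lm-eval scores)."""
    ids = tok(context + continuation, return_tensors="pt").input_ids
    n_ctx = len(tok(context).input_ids)
    with torch.no_grad():
        logits = model(input_ids=ids).logits      # (1, T, V): needs ALL positions
    logprobs = F.log_softmax(logits[0, :-1], -1)  # position t predicts token t+1
    cont = ids[0, n_ctx:]
    return logprobs[torch.arange(n_ctx - 1, ids.size(1) - 1), cont].sum().item()

print(loglikelihood("The capital of France is", " Paris"))
```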

@@ -585,7 +585,7 @@ class GPT(nn.Module):
         return optimizer
 
-    def forward(self, idx, targets=None):
+    def forward(self, idx, targets=None, return_full_logits: bool = False):
         device = idx.device
         b, t = idx.size()
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
@@ -599,22 +599,32 @@ class GPT(nn.Module):
             x = block(x)
         x = self.transformer.ln_f(x)
-        if targets is not None:
-            # if we are given some desired targets also calculate the loss
+        if targets is not None or return_full_logits:
+            # full logits are needed for eval tooling; loss is optional
             logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-            # add the auxiliary load balancing loss and router z loss to the main loss
-            if self.config.n_exp > 1 and self.config.use_aux_loss:
-                loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
-                MANAGER.reset_aux_loss()
-            if self.config.n_exp > 1 and self.config.use_router_z_loss:
-                loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
-                MANAGER.reset_router_z_loss()
+            if targets is not None:
+                # if we are given some desired targets also calculate the loss
+                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+                # add the auxiliary load balancing loss and router z loss to the main loss
+                if self.config.n_exp > 1 and self.config.use_aux_loss:
+                    loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
+                    MANAGER.reset_aux_loss()
+                if self.config.n_exp > 1 and self.config.use_router_z_loss:
+                    loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
+                    MANAGER.reset_router_z_loss()
+            else:
+                loss = None
+                if self.config.n_exp > 1:
+                    MANAGER.reset_aux_loss()
+                    MANAGER.reset_router_z_loss()
         else:
             # inference-time mini-optimization: only forward the lm_head on the very last position
             logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
             loss = None
+            if self.config.n_exp > 1:
+                MANAGER.reset_aux_loss()
+                MANAGER.reset_router_z_loss()
         return logits, loss
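In sketch form, the forward contract after this change, assuming `model` is a constructed `GPT` instance from this file (shapes use batch B, time T, vocab V). Note the router's accumulated aux/z losses are now reset on every path, so no state leaks between calls:

```python
import torch

# Illustrative shape check of the three forward modes (assumes `model` is a
# GPT instance from this repo with T <= block_size).
B, T, V = 2, 16, model.config.vocab_size
idx = torch.randint(0, V, (B, T))
targets = torch.randint(0, V, (B, T))

logits, loss = model(idx, targets=targets)          # training: (B, T, V) + scalar loss
assert logits.shape == (B, T, V) and loss is not None

logits, loss = model(idx, return_full_logits=True)  # eval tooling: (B, T, V), loss None
assert logits.shape == (B, T, V) and loss is None

logits, loss = model(idx)                           # generation: (B, 1, V), loss None
assert logits.shape[1] == 1 and loss is None
```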


@@ -64,7 +64,7 @@ HF_CONFIG_FIELDS = {
     "dropout",
     "bias",
     "n_exp",
-    "top_k",
+    "router_top_k",
     "use_aux_loss",
     "use_router_z_loss",
     "use_noisy_top_k",
@@ -320,7 +320,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         dropout: float = 0.0,
         bias: bool = False,
         n_exp: int = 8,
-        top_k: int = 2,
+        router_top_k: int = 2,
         use_aux_loss: bool = True,
         use_router_z_loss: bool = True,
         use_noisy_top_k: bool = False,
@@ -335,6 +335,10 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         router_use_full_prec: bool = True,
         **kwargs,
     ):
+        if "top_k" in kwargs and "router_top_k" not in kwargs:
+            kwargs["router_top_k"] = kwargs.pop("top_k")
+        if "router_top_k" in kwargs:
+            router_top_k = kwargs.pop("router_top_k")
         kwargs.setdefault("tie_word_embeddings", False)
         super().__init__(**kwargs)
         self.block_size = block_size
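The four inserted lines make `top_k` a legacy alias: configs serialized before the rename (which stored `"top_k"` in `config.json`) are rewritten to `router_top_k` before `PretrainedConfig.__init__` runs. A quick behavioral sketch, assuming the class as defined above:

```python
# Both spellings should populate the same field; "top_k" arrives via **kwargs
# and is rewritten, then popped into router_top_k, before super().__init__.
cfg_new = NanoChatMoEHFConfig(router_top_k=2)
cfg_old = NanoChatMoEHFConfig(top_k=2)  # legacy key from older exports
assert cfg_new.router_top_k == cfg_old.router_top_k == 2
```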
@@ -345,7 +349,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         self.dropout = dropout
         self.bias = bias
         self.n_exp = n_exp
-        self.top_k = top_k
+        self.router_top_k = router_top_k
         self.use_aux_loss = use_aux_loss
         self.use_router_z_loss = use_router_z_loss
         self.use_noisy_top_k = use_noisy_top_k
@@ -382,7 +386,7 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
             dropout=config.dropout,
             bias=config.bias,
             n_exp=config.n_exp,
-            top_k=config.top_k,
+            top_k=config.router_top_k,
             use_aux_loss=config.use_aux_loss,
             use_router_z_loss=config.use_router_z_loss,
             use_noisy_top_k=config.use_noisy_top_k,
@@ -421,10 +425,10 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
         if input_ids is None:
             raise ValueError("input_ids must be provided")
         if labels is None:
-            logits, _ = self.model(input_ids)
+            logits, _ = self.model(input_ids, return_full_logits=True)
             loss = None
         else:
-            logits, loss = self.model(input_ids, targets=labels)
+            logits, loss = self.model(input_ids, targets=labels, return_full_logits=True)
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
@@ -459,7 +463,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         dropout=0.0,
         bias=False,
         n_exp=8,
-        top_k=2,
+        router_top_k=2,
         use_aux_loss=True,
         use_router_z_loss=True,
         use_noisy_top_k=False,
@@ -474,6 +478,10 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         router_use_full_prec=True,
         **kwargs,
     ):
+        if "top_k" in kwargs and "router_top_k" not in kwargs:
+            kwargs["router_top_k"] = kwargs.pop("top_k")
+        if "router_top_k" in kwargs:
+            router_top_k = kwargs.pop("router_top_k")
         kwargs.setdefault("tie_word_embeddings", False)
         super().__init__(**kwargs)
         self.block_size = block_size
@@ -484,7 +492,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         self.dropout = dropout
         self.bias = bias
         self.n_exp = n_exp
-        self.top_k = top_k
+        self.router_top_k = router_top_k
         self.use_aux_loss = use_aux_loss
         self.use_router_z_loss = use_router_z_loss
         self.use_noisy_top_k = use_noisy_top_k
@@ -532,7 +540,7 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
             dropout=config.dropout,
             bias=config.bias,
             n_exp=config.n_exp,
-            top_k=config.top_k,
+            top_k=config.router_top_k,
             use_aux_loss=config.use_aux_loss,
             use_router_z_loss=config.use_router_z_loss,
             use_noisy_top_k=config.use_noisy_top_k,
@@ -563,7 +571,7 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
     def forward(self, input_ids=None, attention_mask=None, labels=None, past_key_values=None, **_):
         if input_ids is None:
             raise ValueError("input_ids must be provided")
-        logits, loss = self.model(input_ids, targets=labels)
+        logits, loss = self.model(input_ids, targets=labels, return_full_logits=True)
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
@@ -987,7 +995,7 @@ class GPT(nn.Module):
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
-    def forward(self, idx, targets=None):
+    def forward(self, idx, targets=None, return_full_logits=False):
         B, T = idx.size()
         assert T <= self.config.block_size
         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
@@ -996,18 +1004,27 @@ class GPT(nn.Module):
         for block in self.transformer.h:
             x = block(x)
         x = self.transformer.ln_f(x)
-        if targets is not None:
+        if targets is not None or return_full_logits:
             logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-            if self.config.n_exp > 1 and self.config.use_aux_loss:
-                loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
-                MANAGER.reset_aux_loss()
-            if self.config.n_exp > 1 and self.config.use_router_z_loss:
-                loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
-                MANAGER.reset_router_z_loss()
+            if targets is not None:
+                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+                if self.config.n_exp > 1 and self.config.use_aux_loss:
+                    loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
+                    MANAGER.reset_aux_loss()
+                if self.config.n_exp > 1 and self.config.use_router_z_loss:
+                    loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
+                    MANAGER.reset_router_z_loss()
+            else:
+                loss = None
+                if self.config.n_exp > 1:
+                    MANAGER.reset_aux_loss()
+                    MANAGER.reset_router_z_loss()
         else:
             logits = self.lm_head(x[:, [-1], :])
             loss = None
+            if self.config.n_exp > 1:
+                MANAGER.reset_aux_loss()
+                MANAGER.reset_router_z_loss()
         return logits, loss
 """
@@ -1033,6 +1050,10 @@ def export_to_hf(
     tokenizer = load_export_tokenizer(tokenizer_mode)
     model, meta, cfg_kwargs = load_moe_checkpoint(source, model_tag, step, device, tokenizer)
     hf_kwargs = {k: v for k, v in cfg_kwargs.items() if k in HF_CONFIG_FIELDS}
+    if "router_top_k" in cfg_kwargs:
+        hf_kwargs["router_top_k"] = cfg_kwargs["router_top_k"]
+    elif "top_k" in cfg_kwargs:
+        hf_kwargs["router_top_k"] = cfg_kwargs["top_k"]
     hf_config = NanoChatMoEHFConfig(**hf_kwargs)
     hf_model = NanoChatMoEHFForCausalLM(hf_config)
     hf_model.model.load_state_dict(model.state_dict(), strict=True)
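Taken together with the config alias, `export_to_hf` now tolerates checkpoints whose `cfg_kwargs` carry either spelling of the router fan-out. A hedged round-trip check (output directory from the commands at the top of this diff; assumes the export completed and registered its `auto_map` for `trust_remote_code` loading):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

path = "hf-export/moe_sft"  # any export directory from the commands above
cfg = AutoConfig.from_pretrained(path, trust_remote_code=True)
print(cfg.n_exp, cfg.router_top_k)  # router hyperparameters survive the rename

model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).eval()
ids = torch.randint(0, cfg.vocab_size, (1, 8))
out = model(input_ids=ids)
assert out.logits.shape == (1, 8, cfg.vocab_size)  # full logits, not just the last step
```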