gpt2 backend

This commit is contained in:
Muheng 2026-01-06 05:05:51 +00:00
parent 0582d669b6
commit 74d94c923f
2 changed files with 94 additions and 6 deletions

View File

@ -14,7 +14,7 @@ source .venv/bin/activate
```bash
# export latest base checkpoint to hf-export/moe_std (gpt2 tokenizer)
uv run python -m nanochat.to_hf --source base --model-tag d20 --step 49000 --output hf-export/moe_std --tokenizer gpt2
uv run python -m nanochat.to_hf --source base --model-tag d00 --output hf-export/moe_legacy --tokenizer gpt2
# export latest SFT checkpoint (chat model, rustbpe tokenizer)
uv run python -m nanochat.to_hf --source sft --output hf-export/moe_sft --tokenizer cache
```
@ -38,17 +38,24 @@ uv run lm-eval run --model hf \
# commonsense benchmarks: HellaSwag, BoolQ, PIQA, Winograd-style
# (Winograd alternatives: winogrande (preferred) or wsc273 (classic WSC))
HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
--model_args pretrained=hf-export/moe_std,trust_remote_code=True,tokenizer=hf-export/moe_std,max_length=1024 \
--model_args pretrained=hf-export/moe_sft_lr8,trust_remote_code=True,tokenizer=hf-export/moe_sft_lr8,max_length=1024 \
--tasks hellaswag,boolq,piqa,winogrande \
--batch_size 1 \
--log_samples \
--output_path lm_eval_sample_commonsense > commonsense.log 2>&1
--output_path lm_eval_sample_commonsense > sft_lr8_commonsense.log 2>&1
HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
--model_args pretrained=hf-export/moe_sft_lr0.9,trust_remote_code=True,tokenizer=hf-export/moe_sft_lr0.9,max_length=1024 \
--tasks hellaswag,boolq,piqa,winogrande,arc_easy,arc_challenge,mmlu \
--batch_size 1 \
--log_samples \
--output_path lm_eval_sample_commonsense > moe_sft_lr0.9_all.log 2>&1
# arc_easy,arc_challenge,mmlu
HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
--model_args pretrained=hf-export/moe_std,trust_remote_code=True,tokenizer=hf-export/moe_std,max_length=1024 \
--model_args pretrained=hf-export/moe_mid,trust_remote_code=True,tokenizer=hf-export/moe_mid,max_length=1024 \
--tasks arc_easy,arc_challenge,mmlu \
--batch_size 1 > moe_std_arc_mmlu.log 2>&1
--batch_size 1 > moe_mid_arc_mmlu.log 2>&1
# gsm8k, humaneval

View File

@ -183,6 +183,50 @@ def pad_token_embeddings(state_dict: dict, key: str, new_vocab: int):
state_dict[key] = torch.cat([tensor, pad], dim=0)
def infer_vocab_size(state_dict: dict) -> Optional[int]:
    """Derive the vocabulary size from checkpoint weights.

    Looks at the token-embedding table first, then the LM head, and returns
    the first dimension of whichever is present. Returns None when neither
    tensor exists in the state dict.
    """
    for weight_key in ("transformer.wte.weight", "lm_head.weight"):
        tensor = state_dict.get(weight_key)
        if tensor is not None:
            return int(tensor.shape[0])
    return None
def infer_bias(state_dict: dict) -> bool:
    """Report whether any parameter name in the state dict is a bias term."""
    for name in state_dict:
        if name.endswith(".bias"):
            return True
    return False
def infer_moe_layers(state_dict: dict) -> list[int]:
    """Collect the sorted indices of transformer blocks that carry an MoE router.

    A block counts as MoE when some parameter name matches
    ``transformer.h.<idx>.…mlp.router.…`` with a numeric ``<idx>``.
    """
    found: set[int] = set()
    for name in state_dict:
        if not name.startswith("transformer.h.") or ".mlp.router." not in name:
            continue
        pieces = name.split(".")
        # pieces[2] is the layer index in "transformer.h.<idx>...."
        if len(pieces) > 2 and pieces[2].isdigit():
            found.add(int(pieces[2]))
    return sorted(found)
def infer_n_exp(state_dict: dict) -> Optional[int]:
    """Read the expert count from the first stacked expert weight found.

    Expert weights are stacked along dim 0 of ``…mlp.experts.c_fc``; returns
    None when no such tensor exists.
    """
    return next(
        (int(t.shape[0]) for k, t in state_dict.items() if k.endswith("mlp.experts.c_fc")),
        None,
    )
def infer_stride(n_layer: Optional[int], moe_layers: list[int]) -> Optional[int]:
    """Recover the MoE placement stride from the observed MoE layer indices.

    Searches for the smallest stride ``s`` such that exactly the layers
    ``{i in [0, n_layer) : i % s == 0}`` are MoE layers. Returns None when
    ``n_layer`` is unknown, ``moe_layers`` is empty, or no stride reproduces
    the observed set.
    """
    if n_layer is None or not moe_layers:
        return None
    target = set(moe_layers)
    return next(
        (
            s
            for s in range(1, n_layer + 1)
            if {i for i in range(n_layer) if i % s == 0} == target
        ),
        None,
    )
def infer_use_noisy_top_k(state_dict: dict) -> bool:
    """Detect noisy top-k routing via the presence of a router noise projection."""
    for name in state_dict:
        if name.endswith("mlp.router.w_noise.weight"):
            return True
    return False
def load_moe_checkpoint(
source: str,
model_tag: Optional[str],
@ -212,8 +256,45 @@ def load_moe_checkpoint(
meta = json.load(f)
cfg_kwargs = normalize_config(meta["model_config"])
tok_vocab = tokenizer.get_vocab_size()
cfg_updates = {}
state_vocab = infer_vocab_size(model_data)
cfg_vocab = cfg_kwargs.get("vocab_size")
if state_vocab is not None:
if cfg_vocab is None:
cfg_updates["vocab_size"] = state_vocab
cfg_vocab = state_vocab
elif cfg_vocab > state_vocab:
pad_token_embeddings(model_data, "transformer.wte.weight", cfg_vocab)
pad_token_embeddings(model_data, "lm_head.weight", cfg_vocab)
elif cfg_vocab < state_vocab:
cfg_updates["vocab_size"] = state_vocab
cfg_vocab = state_vocab
bias = infer_bias(model_data)
if cfg_kwargs.get("bias") != bias:
cfg_updates["bias"] = bias
moe_layers = infer_moe_layers(model_data)
if moe_layers:
n_exp = infer_n_exp(model_data)
if n_exp is not None and cfg_kwargs.get("n_exp") != n_exp:
cfg_updates["n_exp"] = n_exp
stride = infer_stride(cfg_kwargs.get("n_layer"), moe_layers)
if stride is not None and cfg_kwargs.get("stride") != stride:
cfg_updates["stride"] = stride
use_noisy = infer_use_noisy_top_k(model_data)
if cfg_kwargs.get("use_noisy_top_k") != use_noisy:
cfg_updates["use_noisy_top_k"] = use_noisy
else:
if cfg_kwargs.get("n_exp") not in (None, 1):
cfg_updates["n_exp"] = 1
if cfg_updates:
cfg_kwargs.update(cfg_updates)
meta["model_config"] = cfg_kwargs
print(f"[to_hf] Inferred config overrides: {cfg_updates}")
tok_vocab = tokenizer.get_vocab_size()
if cfg_vocab is None:
cfg_kwargs["vocab_size"] = tok_vocab
elif tok_vocab > cfg_vocab: