mirror of https://github.com/karpathy/nanochat.git
synced 2026-04-29 20:40:23 +00:00

gpt2 backend

This commit is contained in:
parent 0582d669b6
commit 74d94c923f

lm_eval.md | 17
lm_eval.md

````diff
@@ -14,7 +14,7 @@ source .venv/bin/activate
 ```bash
 # export latest base checkpoint to hf-export/moe_std (gpt2 tokenizer)
 uv run python -m nanochat.to_hf --source base --model-tag d20 --step 49000 --output hf-export/moe_std --tokenizer gpt2
 uv run python -m nanochat.to_hf --source base --model-tag d00 --output hf-export/moe_legacy --tokenizer gpt2
 # export latest SFT checkpoint (chat model, rustbpe tokenizer)
 uv run python -m nanochat.to_hf --source sft --output hf-export/moe_sft --tokenizer cache
 ```
````
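Before launching a long eval, it is cheap to smoke-test an export by loading it back through `transformers`. A minimal sketch, assuming the export above succeeded, that the checkpoint loads with `trust_remote_code=True` (as in the lm-eval invocations below), and that the exported model class supports `generate`; the prompt is illustrative:

```python
# Smoke-test an exported checkpoint before running lm-eval on it.
# hf-export/moe_std is the path produced by the export command above.
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "hf-export/moe_std"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

# Greedy-decode a short continuation; coherent output confirms the
# config, weights, and tokenizer round-tripped through the export.
inputs = tokenizer("The capital of France is", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=16, do_sample=False)
print(tokenizer.decode(out[0]))
```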
```diff
@@ -38,17 +38,24 @@ uv run lm-eval run --model hf \
 # commonsense benchmarks: HellaSwag, BoolQ, PIQA, Winograd-style
 # (Winograd alternatives: winogrande (preferred) or wsc273 (classic WSC))
 HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
-    --model_args pretrained=hf-export/moe_std,trust_remote_code=True,tokenizer=hf-export/moe_std,max_length=1024 \
+    --model_args pretrained=hf-export/moe_sft_lr8,trust_remote_code=True,tokenizer=hf-export/moe_sft_lr8,max_length=1024 \
     --tasks hellaswag,boolq,piqa,winogrande \
     --batch_size 1 \
     --log_samples \
-    --output_path lm_eval_sample_commonsense > commonsense.log 2>&1
+    --output_path lm_eval_sample_commonsense > sft_lr8_commonsense.log 2>&1
+
+HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
+    --model_args pretrained=hf-export/moe_sft_lr0.9,trust_remote_code=True,tokenizer=hf-export/moe_sft_lr0.9,max_length=1024 \
+    --tasks hellaswag,boolq,piqa,winogrande,arc_easy,arc_challenge,mmlu \
+    --batch_size 1 \
+    --log_samples \
+    --output_path lm_eval_sample_commonsense > moe_sft_lr0.9_all.log 2>&1

 # arc_easy,arc_challenge,mmlu
 HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
-    --model_args pretrained=hf-export/moe_std,trust_remote_code=True,tokenizer=hf-export/moe_std,max_length=1024 \
+    --model_args pretrained=hf-export/moe_mid,trust_remote_code=True,tokenizer=hf-export/moe_mid,max_length=1024 \
     --tasks arc_easy,arc_challenge,mmlu \
-    --batch_size 1 > moe_std_arc_mmlu.log 2>&1
+    --batch_size 1 > moe_mid_arc_mmlu.log 2>&1

 # gsm8k, humaneval
```
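Once a run finishes, the harness writes JSON under `--output_path`. A minimal sketch for summarizing the accuracy metrics, assuming the usual lm-eval layout of a `results_*.json` file with a top-level `results` dict keyed by task (exact directory nesting varies across harness versions):

```python
# Print acc / acc_norm for each task from the latest lm-eval results file.
import glob
import json

paths = sorted(glob.glob("lm_eval_sample_commonsense/**/results_*.json", recursive=True))
assert paths, "no results files found yet"

with open(paths[-1]) as f:  # timestamps sort lexicographically, take newest
    results = json.load(f)["results"]

for task, metrics in results.items():
    for name, value in metrics.items():
        # keep only the float-valued accuracy entries, skipping
        # stderr entries and the string-valued "alias" field
        if name.startswith(("acc,", "acc_norm,")):
            print(f"{task:20s} {name:14s} {value:.4f}")
```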
nanochat/to_hf.py

```diff
@@ -183,6 +183,50 @@ def pad_token_embeddings(state_dict: dict, key: str, new_vocab: int):
     state_dict[key] = torch.cat([tensor, pad], dim=0)


+def infer_vocab_size(state_dict: dict) -> Optional[int]:
+    if "transformer.wte.weight" in state_dict:
+        return int(state_dict["transformer.wte.weight"].shape[0])
+    if "lm_head.weight" in state_dict:
+        return int(state_dict["lm_head.weight"].shape[0])
+    return None
+
+
+def infer_bias(state_dict: dict) -> bool:
+    return any(k.endswith(".bias") for k in state_dict)
+
+
+def infer_moe_layers(state_dict: dict) -> list[int]:
+    layers = set()
+    for key in state_dict:
+        if key.startswith("transformer.h.") and ".mlp.router." in key:
+            parts = key.split(".")
+            if len(parts) > 2 and parts[2].isdigit():
+                layers.add(int(parts[2]))
+    return sorted(layers)
+
+
+def infer_n_exp(state_dict: dict) -> Optional[int]:
+    for key, tensor in state_dict.items():
+        if key.endswith("mlp.experts.c_fc"):
+            return int(tensor.shape[0])
+    return None
+
+
+def infer_stride(n_layer: Optional[int], moe_layers: list[int]) -> Optional[int]:
+    if not moe_layers or n_layer is None:
+        return None
+    layer_set = set(moe_layers)
+    for stride in range(1, n_layer + 1):
+        expected = {i for i in range(n_layer) if i % stride == 0}
+        if expected == layer_set:
+            return stride
+    return None
+
+
+def infer_use_noisy_top_k(state_dict: dict) -> bool:
+    return any(key.endswith("mlp.router.w_noise.weight") for key in state_dict)
+
+
 def load_moe_checkpoint(
     source: str,
     model_tag: Optional[str],
```
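The new helpers operate purely on checkpoint key names and tensor shapes, so they are easy to sanity-check against a fabricated state dict. A toy walkthrough, assuming the helpers are importable from `nanochat.to_hf`; the `w_gate` router key name and the tensor shapes are illustrative (only the `.mlp.router.` and `mlp.experts.c_fc` patterns matter to the matchers):

```python
import torch
from nanochat.to_hf import (
    infer_moe_layers,
    infer_n_exp,
    infer_stride,
    infer_use_noisy_top_k,
)

# MoE routers on layers 0, 2, 4 of a 6-layer model; 4 stacked experts.
state_dict = {
    "transformer.wte.weight": torch.zeros(50257, 8),
    "transformer.h.0.mlp.router.w_gate.weight": torch.zeros(4, 8),
    "transformer.h.2.mlp.router.w_gate.weight": torch.zeros(4, 8),
    "transformer.h.4.mlp.router.w_gate.weight": torch.zeros(4, 8),
    "transformer.h.0.mlp.experts.c_fc": torch.zeros(4, 8, 32),
}

moe_layers = infer_moe_layers(state_dict)
print(moe_layers)                         # [0, 2, 4]
print(infer_n_exp(state_dict))            # 4 (dim 0 of the expert stack)
print(infer_stride(6, moe_layers))        # 2 (every other layer of 6)
print(infer_use_noisy_top_k(state_dict))  # False (no w_noise weights)
```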
```diff
@@ -212,8 +256,45 @@ def load_moe_checkpoint(
         meta = json.load(f)
     cfg_kwargs = normalize_config(meta["model_config"])

-    tok_vocab = tokenizer.get_vocab_size()
+    cfg_updates = {}
+    state_vocab = infer_vocab_size(model_data)
+    cfg_vocab = cfg_kwargs.get("vocab_size")
+    if state_vocab is not None:
+        if cfg_vocab is None:
+            cfg_updates["vocab_size"] = state_vocab
+            cfg_vocab = state_vocab
+        elif cfg_vocab > state_vocab:
+            pad_token_embeddings(model_data, "transformer.wte.weight", cfg_vocab)
+            pad_token_embeddings(model_data, "lm_head.weight", cfg_vocab)
+        elif cfg_vocab < state_vocab:
+            cfg_updates["vocab_size"] = state_vocab
+            cfg_vocab = state_vocab
+
+    bias = infer_bias(model_data)
+    if cfg_kwargs.get("bias") != bias:
+        cfg_updates["bias"] = bias
+
+    moe_layers = infer_moe_layers(model_data)
+    if moe_layers:
+        n_exp = infer_n_exp(model_data)
+        if n_exp is not None and cfg_kwargs.get("n_exp") != n_exp:
+            cfg_updates["n_exp"] = n_exp
+        stride = infer_stride(cfg_kwargs.get("n_layer"), moe_layers)
+        if stride is not None and cfg_kwargs.get("stride") != stride:
+            cfg_updates["stride"] = stride
+        use_noisy = infer_use_noisy_top_k(model_data)
+        if cfg_kwargs.get("use_noisy_top_k") != use_noisy:
+            cfg_updates["use_noisy_top_k"] = use_noisy
+    else:
+        if cfg_kwargs.get("n_exp") not in (None, 1):
+            cfg_updates["n_exp"] = 1
+
+    if cfg_updates:
+        cfg_kwargs.update(cfg_updates)
+        meta["model_config"] = cfg_kwargs
+        print(f"[to_hf] Inferred config overrides: {cfg_updates}")
+
+    tok_vocab = tokenizer.get_vocab_size()
     if cfg_vocab is None:
         cfg_kwargs["vocab_size"] = tok_vocab
     elif tok_vocab > cfg_vocab:
```
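The reconciliation covers both directions of a vocab mismatch: when the config's `vocab_size` is larger than the checkpoint's embedding rows (e.g. rounded up for kernel efficiency), the tensors are padded up to the config; when it is smaller, the config is treated as stale and bumped to match the tensors. A small sketch of the padding direction, assuming `pad_token_embeddings` is importable from `nanochat.to_hf` and appends rows along dim 0 as the `torch.cat` in the earlier hunk shows:

```python
import torch
from nanochat.to_hf import pad_token_embeddings

# Checkpoint has 50257 embedding rows; config says 50304.
state_dict = {"transformer.wte.weight": torch.randn(50257, 8)}
cfg_vocab = 50304
state_vocab = state_dict["transformer.wte.weight"].shape[0]

if cfg_vocab > state_vocab:
    # Config is wider: pad embedding rows up to the config's vocab.
    pad_token_embeddings(state_dict, "transformer.wte.weight", cfg_vocab)
elif cfg_vocab < state_vocab:
    # Config is stale: trust the tensors.
    cfg_vocab = state_vocab

print(state_dict["transformer.wte.weight"].shape)  # torch.Size([50304, 8])
```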