debug hf inference

Commit: 6fb1d64864
Parent: 6095f82fdd
Repository: karpathy/nanochat (mirror of https://github.com/karpathy/nanochat.git)

.gitignore (vendored): 3 changed lines

@@ -5,4 +5,5 @@ rustbpe/target/
 dev-ignore/
 report.md
 eval_bundle/
-hf-export/*
+hf-export/*
+*.log

Docs (lm-eval example commands):

@@ -29,13 +29,13 @@ Example runs:
 ```bash
 # Single task (MMLU)
 uv run lm-eval run --model hf \
-    --model_args pretrained=hf-export/sft \
+    --model_args pretrained=hf-export/sft,trust_remote_code=True \
     --tasks mmlu \
     --batch_size 1
 
 # A small suite similar to nanochat chat_eval coverage
 uv run lm-eval run --model hf \
-    --model_args pretrained=hf-export/sft \
+    --model_args pretrained=hf-export/sft,trust_remote_code=True \
     --tasks arc_easy,arc_challenge,gsm8k,mmlu,humaneval \
     --batch_size 1
 ```
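
Note: the same flag matters outside lm-eval. `trust_remote_code=True` is what lets the `auto_map` entries written by the export script (further down in this commit) resolve to the bundled `configuration_nanochat.py` / `modeling_nanochat.py`. A minimal sketch of a direct load, assuming the export lives at `hf-export/sft` as in the commands above:

```python
# Sketch only: load the exported checkpoint directly with transformers.
# "hf-export/sft" is the example path from the docs above, not a fixed location.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "hf-export/sft",
    trust_remote_code=True,  # lets auto_map pull in the custom nanochat classes
)
print(type(model).__name__)  # expected: the custom causal-LM wrapper class
```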

Checkpoint discovery helper (find_largest_model):
@@ -101,7 +101,6 @@ def find_largest_model(checkpoint_dir):
     # 1) normally all model tags are of the form d<number>, try that first:
     candidates = []
     for model_tag in model_tags:
-        print(model_tag)
         match = re.match(r"d(\d+)", model_tag)
         if match:
             model_depth = int(match.group(1))

HF export script:
@@ -20,10 +20,12 @@ import torch.nn.functional as F
 try:
     from transformers import PreTrainedModel, PretrainedConfig
     from transformers.modeling_outputs import CausalLMOutputWithPast
+    from transformers.generation.utils import GenerationMixin
 except ImportError as exc:
     raise SystemExit(
         "transformers is required for HF export. Run `uv sync` (with the hf extra) first."
     ) from exc
 
+
 from nanochat.checkpoint_manager import load_model
 from nanochat.gpt import GPT, GPTConfig

@@ -55,7 +57,7 @@ class NanoChatHFConfig(PretrainedConfig):
         self.n_embd = n_embd
 
 
-class NanoChatHFForCausalLM(PreTrainedModel):
+class NanoChatHFForCausalLM(PreTrainedModel, GenerationMixin):
     config_class = NanoChatHFConfig
 
     def __init__(self, config: NanoChatHFConfig):
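
Context for this one-line change: in newer transformers releases, `PreTrainedModel` no longer supplies `.generate()` by itself, so a causal-LM wrapper must inherit `GenerationMixin` explicitly. A hedged sketch of the call this is meant to unblock; the prompt, the decoding settings, and the assumption that the copied tokenizer files load via `AutoTokenizer` are illustrative, not taken from the repo:

```python
# Sketch only: exercise .generate() on the exported model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-export/sft", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("hf-export/sft")  # assumes HF-compatible tokenizer files were copied

inputs = tokenizer("Why is the sky blue?", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0]))
```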

@@ -128,6 +130,76 @@ def copy_tokenizer_files(output_dir: str):
         shutil.copy2(src, dst)
     print(f"[to_hf] Copied tokenizer files from {tokenizer_dir} to {output_dir}")
 
 
+def write_hf_code(output_dir: str):
+    cfg_py = r'''
+from transformers import PretrainedConfig
+
+
+class NanoChatHFConfig(PretrainedConfig):
+    model_type = "nanochat"
+
+    def __init__(self, sequence_len=1024, vocab_size=50304, n_layer=12, n_head=6, n_kv_head=6, n_embd=768, **kwargs):
+        kwargs.setdefault("tie_word_embeddings", False)
+        super().__init__(**kwargs)
+        self.sequence_len = sequence_len
+        self.vocab_size = vocab_size
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.n_kv_head = n_kv_head
+        self.n_embd = n_embd
+'''
+    mdl_py = r'''
+import torch
+import torch.nn.functional as F
+from transformers import PreTrainedModel
+from transformers.generation.utils import GenerationMixin
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from configuration_nanochat import NanoChatHFConfig
+from nanochat.gpt import GPT, GPTConfig
+
+
+class NanoChatHFForCausalLM(PreTrainedModel, GenerationMixin):
+    config_class = NanoChatHFConfig
+
+    def __init__(self, config: NanoChatHFConfig):
+        super().__init__(config)
+        gpt_cfg = GPTConfig(
+            sequence_len=config.sequence_len,
+            vocab_size=config.vocab_size,
+            n_layer=config.n_layer,
+            n_head=config.n_head,
+            n_kv_head=config.n_kv_head,
+            n_embd=config.n_embd,
+        )
+        self.model = GPT(gpt_cfg)
+
+    def get_input_embeddings(self):
+        return self.model.transformer.wte
+
+    def set_input_embeddings(self, value):
+        self.model.transformer.wte = value
+
+    def get_output_embeddings(self):
+        return self.model.lm_head
+
+    def tie_weights(self):
+        return
+
+    def forward(self, input_ids=None, attention_mask=None, labels=None, past_key_values=None, **_):
+        if input_ids is None:
+            raise ValueError("input_ids must be provided")
+        logits = self.model(input_ids)
+        loss = None
+        if labels is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None)
+
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
+'''
+    with open(os.path.join(output_dir, "configuration_nanochat.py"), "w") as f:
+        f.write(cfg_py)
+    with open(os.path.join(output_dir, "modeling_nanochat.py"), "w") as f:
+        f.write(mdl_py)
+
+
 def export_to_hf(source: str, output_dir: str, model_tag: Optional[str], step: Optional[int]):
     device = torch.device("cpu")
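
One consequence of the generated `modeling_nanochat.py` worth flagging: it imports `from nanochat.gpt import GPT, GPTConfig`, so the exported folder only loads in an environment where the `nanochat` package itself is importable (e.g. the repo's own `uv` environment). A small sanity check, a sketch only, before pointing lm-eval at the export:

```python
# Sketch only: confirm the imports the generated modeling file relies on resolve here.
import importlib.util

for mod in ("nanochat", "transformers"):
    if importlib.util.find_spec(mod) is None:
        raise SystemExit(f"{mod} is not importable here; loading the export will fail")
```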

@@ -138,11 +210,18 @@ def export_to_hf(source: str, output_dir: str, model_tag: Optional[str], step: Optional[int]):
     hf_model.model.load_state_dict(model.state_dict(), strict=True)
 
     os.makedirs(output_dir, exist_ok=True)
+    # Tell transformers how to load custom code when trust_remote_code=True
+    hf_model.config.auto_map = {
+        "AutoConfig": "configuration_nanochat.NanoChatHFConfig",
+        "AutoModelForCausalLM": "modeling_nanochat.NanoChatHFForCausalLM",
+    }
+    hf_model.config.architectures = ["NanoChatHFForCausalLM"]
     hf_model.save_pretrained(output_dir, safe_serialization=False)
     # Best effort: drop tokenizer files alongside weights
     copy_tokenizer_files(output_dir)
     print(f"[to_hf] Exported {source} checkpoint to {output_dir}")
+    write_hf_code(output_dir)
 
 
 def main():
     parser = argparse.ArgumentParser(description="Export nanochat checkpoint to HuggingFace format")
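
To verify the new wiring after an export, the generated `config.json` can be inspected directly; a sketch, again assuming the `hf-export/sft` output path from the docs:

```python
# Sketch only: check that config.json carries the custom-code mapping set in export_to_hf.
import json
import os

with open(os.path.join("hf-export/sft", "config.json")) as f:
    cfg = json.load(f)

print(cfg.get("auto_map"))       # expect the AutoConfig / AutoModelForCausalLM entries
print(cfg.get("architectures"))  # expect ["NanoChatHFForCausalLM"]
```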