mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-15 18:49:10 +00:00
Allow local install and model loading
This commit is contained in:
parent
4610a838a1
commit
d6829284c4
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -12,3 +12,4 @@ eval_bundle/
|
|||
.claude
|
||||
CLAUDE.md
|
||||
wandb/
|
||||
*.egg-info/
|
||||
|
|
|
|||
|
|
@ -71,3 +71,6 @@ conflicts = [
|
|||
{ extra = "gpu" },
|
||||
],
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["nanochat"]
|
||||
|
|
|
|||
|
|
@ -134,13 +134,16 @@ def load_hf_model(hf_path: str, device):
|
|||
print0(f"Loading model from: {hf_path}")
|
||||
# Load the model
|
||||
from transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
|
||||
model = ModelWrapper(model, max_seq_len=max_seq_len)
|
||||
# Load the tokenizer
|
||||
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
|
||||
if os.path.exists(hf_path):
|
||||
tokenizer = HuggingFaceTokenizer.from_directory(hf_path)
|
||||
else:
|
||||
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
|
||||
return model, tokenizer
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user