Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-22 11:24:17 +00:00)
debug to_hf

Commit: 9b05e7c625
Parent: 1dc944f734
.gitignore (vendored): 6 changed lines
@@ -18,4 +18,8 @@ hf-export/**/__pycache__/
 agent.md
 lm_eval_sample_*/
 *.pt
-hf-export/**/*
+hf-export/**/*
+lm_eval_pretrain/
+benchmark.md
+d*.png
+loss*.png
@@ -14,7 +14,7 @@ source .venv/bin/activate
 ```bash
 # export latest base checkpoint to hf-export/moe_std (gpt2 tokenizer)
 uv run python -m nanochat.to_hf --source base --model-tag d20 --step 49000 --output hf-export/moe_std --tokenizer gpt2
-uv run python -m nanochat.to_hf --source base --model-tag d00 --output hf-export/moe_legacy --tokenizer gpt2
+uv run python -m nanochat.to_hf --source base --model-tag d24_e32_min_lr1e-05_max_lr0.0001 --output hf-export/d24_e32_min_lr1e-05_max_lr0.0001 --tokenizer cache
 # export latest SFT checkpoint (chat model, rustbpe tokenizer)
 uv run python -m nanochat.to_hf --source sft --output hf-export/moe_sft --tokenizer cache
 ```
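A quick way to sanity-check any of these exports is to load the output directory back through the standard transformers API. A minimal sketch, assuming the hf-export/moe_std path from the first command above (the prompt and generation settings are illustrative):

```python
# Minimal smoke test: load an exported directory back with transformers.
# Assumes the export above wrote hf-export/moe_std (path is illustrative).
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "hf-export/moe_std"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
out = model.generate(ids, max_new_tokens=8, do_sample=False)
print(tokenizer.decode(out[0]))
```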
@@ -45,11 +45,11 @@ HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
   --output_path lm_eval_sample_commonsense > sft_lr8_commonsense.log 2>&1

 HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
-  --model_args pretrained=hf-export/moe_sft_lr0.9,trust_remote_code=True,tokenizer=hf-export/moe_sft_lr0.9,max_length=1024 \
+  --model_args pretrained=/thullms/limh23/hf-export/{tag},trust_remote_code=True,tokenizer=/thullms/limh23/hf-export/{tag},max_length=1024 \
   --tasks hellaswag,boolq,piqa,winogrande,arc_easy,arc_challenge,mmlu \
   --batch_size 1 \
   --log_samples \
-  --output_path lm_eval_sample_commonsense > moe_sft_lr0.9_all.log 2>&1
+  --output_path lm_eval_pretrain > benchmark_pretrain/moe_{tag}_all.log 2>&1

 # arc_easy,arc_challenge,mmlu
 HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
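The {tag} placeholders above stand in for a concrete export directory name and are deliberately left unfilled. The same run can also be driven from Python; a rough sketch via the harness API, assuming lm-eval 0.4+ and a hypothetical tag value:

```python
# Rough sketch of the same evaluation through the lm-eval Python API.
# `tag` is the same placeholder as in the shell command; the value below
# is hypothetical and should be a real export directory name.
import lm_eval

tag = "moe_std"
path = f"hf-export/{tag}"
results = lm_eval.simple_evaluate(
    model="hf",
    model_args=f"pretrained={path},trust_remote_code=True,tokenizer={path},max_length=1024",
    tasks=["hellaswag", "boolq", "piqa", "winogrande", "arc_easy", "arc_challenge", "mmlu"],
    batch_size=1,
)
print(results["results"])
```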
@@ -585,7 +585,7 @@ class GPT(nn.Module):
         return optimizer


-    def forward(self, idx, targets=None):
+    def forward(self, idx, targets=None, return_full_logits: bool = False):
         device = idx.device
         b, t = idx.size()
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
@@ -599,22 +599,32 @@ class GPT(nn.Module):
             x = block(x)
         x = self.transformer.ln_f(x)

-        if targets is not None:
-            # if we are given some desired targets also calculate the loss
+        if targets is not None or return_full_logits:
+            # full logits are needed for eval tooling; loss is optional
             logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+            if targets is not None:
+                # if we are given some desired targets also calculate the loss
+                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

-            # add the auxiliary load balancing loss and router z loss to the main loss
-            if self.config.n_exp > 1 and self.config.use_aux_loss:
-                loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
-                MANAGER.reset_aux_loss()
-            if self.config.n_exp > 1 and self.config.use_router_z_loss:
-                loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
-                MANAGER.reset_router_z_loss()
+                # add the auxiliary load balancing loss and router z loss to the main loss
+                if self.config.n_exp > 1 and self.config.use_aux_loss:
+                    loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
+                    MANAGER.reset_aux_loss()
+                if self.config.n_exp > 1 and self.config.use_router_z_loss:
+                    loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
+                    MANAGER.reset_router_z_loss()
+            else:
+                loss = None
+                if self.config.n_exp > 1:
+                    MANAGER.reset_aux_loss()
+                    MANAGER.reset_router_z_loss()
         else:
             # inference-time mini-optimization: only forward the lm_head on the very last position
             logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
             loss = None
+            if self.config.n_exp > 1:
+                MANAGER.reset_aux_loss()
+                MANAGER.reset_router_z_loss()

         return logits, loss
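For context on the new return_full_logits flag: lm-eval scores a candidate continuation by summing per-token log-probabilities, which requires logits for every position, shape (B, T, vocab). The old targets-is-None path returned only the last position, (B, 1, vocab), which supports sampling but not loglikelihood scoring. A standalone sketch of the scoring pattern (illustrative values, not nanochat code):

```python
# Standalone sketch: loglikelihood scoring as an eval harness performs it.
# With full logits (B, T, V) every continuation token can be scored; a
# last-position-only (B, 1, V) output cannot support this.
import torch
import torch.nn.functional as F

B, T, V = 1, 6, 50257
logits = torch.randn(B, T, V)       # stand-in for model output
tokens = torch.randint(V, (B, T))   # prompt + continuation ids
cont_start = 3                      # continuation begins at position 3

logprobs = F.log_softmax(logits, dim=-1)
# logits at position t predict token t+1
pred = logprobs[:, cont_start - 1 : T - 1, :]
gold = tokens[:, cont_start:T]
score = pred.gather(-1, gold.unsqueeze(-1)).sum()
print(score)  # sum of log p(continuation | prefix)
```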
@@ -64,7 +64,7 @@ HF_CONFIG_FIELDS = {
     "dropout",
     "bias",
     "n_exp",
-    "top_k",
+    "router_top_k",
     "use_aux_loss",
     "use_router_z_loss",
     "use_noisy_top_k",
@@ -320,7 +320,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         dropout: float = 0.0,
         bias: bool = False,
         n_exp: int = 8,
-        top_k: int = 2,
+        router_top_k: int = 2,
         use_aux_loss: bool = True,
         use_router_z_loss: bool = True,
         use_noisy_top_k: bool = False,
@@ -335,6 +335,10 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         router_use_full_prec: bool = True,
         **kwargs,
     ):
+        if "top_k" in kwargs and "router_top_k" not in kwargs:
+            kwargs["router_top_k"] = kwargs.pop("top_k")
+        if "router_top_k" in kwargs:
+            router_top_k = kwargs.pop("router_top_k")
         kwargs.setdefault("tie_word_embeddings", False)
         super().__init__(**kwargs)
         self.block_size = block_size
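The rename from top_k to router_top_k likely avoids a collision with top_k, which transformers already treats as a generation/sampling parameter on configs; the kwargs shim above keeps older exported config.json files loadable by remapping the old key. A quick check of that path, assuming the class is importable from nanochat.to_hf as the module path in the commands above suggests:

```python
# Hypothetical back-compat check: an older config.json that still stores
# the router fan-out as "top_k" should land in router_top_k after the shim.
from nanochat.to_hf import NanoChatMoEHFConfig  # assumed import location

old_style = NanoChatMoEHFConfig(top_k=4)  # "top_k" arrives via **kwargs
assert old_style.router_top_k == 4

new_style = NanoChatMoEHFConfig(router_top_k=4)
assert new_style.router_top_k == 4
```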
@@ -345,7 +349,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         self.dropout = dropout
         self.bias = bias
         self.n_exp = n_exp
-        self.top_k = top_k
+        self.router_top_k = router_top_k
         self.use_aux_loss = use_aux_loss
         self.use_router_z_loss = use_router_z_loss
         self.use_noisy_top_k = use_noisy_top_k
@@ -382,7 +386,7 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
             dropout=config.dropout,
             bias=config.bias,
             n_exp=config.n_exp,
-            top_k=config.top_k,
+            top_k=config.router_top_k,
             use_aux_loss=config.use_aux_loss,
             use_router_z_loss=config.use_router_z_loss,
             use_noisy_top_k=config.use_noisy_top_k,
@@ -421,10 +425,10 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
         if input_ids is None:
             raise ValueError("input_ids must be provided")
         if labels is None:
-            logits, _ = self.model(input_ids)
+            logits, _ = self.model(input_ids, return_full_logits=True)
             loss = None
         else:
-            logits, loss = self.model(input_ids, targets=labels)
+            logits, loss = self.model(input_ids, targets=labels, return_full_logits=True)
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
@@ -459,7 +463,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         dropout=0.0,
         bias=False,
         n_exp=8,
-        top_k=2,
+        router_top_k=2,
         use_aux_loss=True,
         use_router_z_loss=True,
         use_noisy_top_k=False,
@@ -474,6 +478,10 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         router_use_full_prec=True,
         **kwargs,
     ):
+        if "top_k" in kwargs and "router_top_k" not in kwargs:
+            kwargs["router_top_k"] = kwargs.pop("top_k")
+        if "router_top_k" in kwargs:
+            router_top_k = kwargs.pop("router_top_k")
         kwargs.setdefault("tie_word_embeddings", False)
         super().__init__(**kwargs)
         self.block_size = block_size
@@ -484,7 +492,7 @@ class NanoChatMoEHFConfig(PretrainedConfig):
         self.dropout = dropout
         self.bias = bias
         self.n_exp = n_exp
-        self.top_k = top_k
+        self.router_top_k = router_top_k
         self.use_aux_loss = use_aux_loss
         self.use_router_z_loss = use_router_z_loss
         self.use_noisy_top_k = use_noisy_top_k
@@ -532,7 +540,7 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
             dropout=config.dropout,
             bias=config.bias,
             n_exp=config.n_exp,
-            top_k=config.top_k,
+            top_k=config.router_top_k,
             use_aux_loss=config.use_aux_loss,
             use_router_z_loss=config.use_router_z_loss,
             use_noisy_top_k=config.use_noisy_top_k,
@@ -563,7 +571,7 @@ class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
     def forward(self, input_ids=None, attention_mask=None, labels=None, past_key_values=None, **_):
         if input_ids is None:
             raise ValueError("input_ids must be provided")
-        logits, loss = self.model(input_ids, targets=labels)
+        logits, loss = self.model(input_ids, targets=labels, return_full_logits=True)
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
@@ -987,7 +995,7 @@ class GPT(nn.Module):
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

-    def forward(self, idx, targets=None):
+    def forward(self, idx, targets=None, return_full_logits=False):
         B, T = idx.size()
         assert T <= self.config.block_size
         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
@@ -996,18 +1004,27 @@
         for block in self.transformer.h:
             x = block(x)
         x = self.transformer.ln_f(x)
-        if targets is not None:
+        if targets is not None or return_full_logits:
             logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-            if self.config.n_exp > 1 and self.config.use_aux_loss:
-                loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
-                MANAGER.reset_aux_loss()
-            if self.config.n_exp > 1 and self.config.use_router_z_loss:
-                loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
-                MANAGER.reset_router_z_loss()
+            if targets is not None:
+                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+                if self.config.n_exp > 1 and self.config.use_aux_loss:
+                    loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
+                    MANAGER.reset_aux_loss()
+                if self.config.n_exp > 1 and self.config.use_router_z_loss:
+                    loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
+                    MANAGER.reset_router_z_loss()
+            else:
+                loss = None
+                if self.config.n_exp > 1:
+                    MANAGER.reset_aux_loss()
+                    MANAGER.reset_router_z_loss()
         else:
             logits = self.lm_head(x[:, [-1], :])
             loss = None
+            if self.config.n_exp > 1:
+                MANAGER.reset_aux_loss()
+                MANAGER.reset_router_z_loss()
         return logits, loss
 """
@@ -1033,6 +1050,10 @@ def export_to_hf(
     tokenizer = load_export_tokenizer(tokenizer_mode)
     model, meta, cfg_kwargs = load_moe_checkpoint(source, model_tag, step, device, tokenizer)
     hf_kwargs = {k: v for k, v in cfg_kwargs.items() if k in HF_CONFIG_FIELDS}
+    if "router_top_k" in cfg_kwargs:
+        hf_kwargs["router_top_k"] = cfg_kwargs["router_top_k"]
+    elif "top_k" in cfg_kwargs:
+        hf_kwargs["router_top_k"] = cfg_kwargs["top_k"]
     hf_config = NanoChatMoEHFConfig(**hf_kwargs)
     hf_model = NanoChatMoEHFForCausalLM(hf_config)
     hf_model.model.load_state_dict(model.state_dict(), strict=True)
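A natural follow-up to the state-dict copy in export_to_hf is a logits parity check between the source model and the HF wrapper. A sketch that could sit right after the load_state_dict line, continuing with the names bound inside the function (the batch shape and tolerances are arbitrary):

```python
# Hypothetical smoke test after export: the wrapped HF model should
# reproduce the source model's full logits on a random batch.
import torch

idx = torch.randint(0, hf_config.vocab_size, (1, 16), device=device)

with torch.no_grad():
    ref_logits, _ = model(idx, return_full_logits=True)  # nanochat GPT
    hf_logits = hf_model(input_ids=idx).logits           # HF wrapper

torch.testing.assert_close(ref_logits, hf_logits, rtol=1e-4, atol=1e-4)
```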