mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-17 22:38:32 +00:00
remove leftover mid references (#491)
This commit is contained in:
parent
b19b4f3e49
commit
72b9064f9d
|
|
@ -164,7 +164,6 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non
|
|||
def load_model(source, *args, **kwargs):
|
||||
model_dir = {
|
||||
"base": "base_checkpoints",
|
||||
"mid": "mid_checkpoints",
|
||||
"sft": "chatsft_checkpoints",
|
||||
"rl": "chatrl_checkpoints",
|
||||
}[source]
|
||||
|
|
|
|||
|
|
@ -211,8 +211,6 @@ EXPECTED_FILES = [
|
|||
"base-model-training.md",
|
||||
"base-model-loss.md",
|
||||
"base-model-evaluation.md",
|
||||
"midtraining.md",
|
||||
"chat-evaluation-mid.md",
|
||||
"chat-sft.md",
|
||||
"chat-evaluation-sft.md",
|
||||
"chat-rl.md",
|
||||
|
|
@ -316,8 +314,6 @@ class Report:
|
|||
# extract the most important metrics from the sections
|
||||
if file_name == "base-model-evaluation.md":
|
||||
final_metrics["base"] = extract(section, "CORE")
|
||||
if file_name == "chat-evaluation-mid.md":
|
||||
final_metrics["mid"] = extract(section, chat_metrics)
|
||||
if file_name == "chat-evaluation-sft.md":
|
||||
final_metrics["sft"] = extract(section, chat_metrics)
|
||||
if file_name == "chat-evaluation-rl.md":
|
||||
|
|
@ -337,7 +333,7 @@ class Report:
|
|||
# Custom ordering: CORE first, ChatCORE last, rest in middle
|
||||
all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
|
||||
# Fixed column widths
|
||||
stages = ["base", "mid", "sft", "rl"]
|
||||
stages = ["base", "sft", "rl"]
|
||||
metric_width = 15
|
||||
value_width = 8
|
||||
# Write table header
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from nanochat.engine import Engine
|
|||
from nanochat.checkpoint_manager import load_model
|
||||
|
||||
parser = argparse.ArgumentParser(description='Chat with the model')
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
|
|
|
|||
|
|
@ -183,7 +183,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Parse command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
|
||||
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,6 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
|
|||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
|
|
@ -77,7 +76,7 @@ use_dummy_wandb = args.run == "dummy" or not master_process
|
|||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
|
||||
|
||||
# Init model and tokenizer
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
engine = Engine(model, tokenizer) # for sampling rollouts
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ MAX_MAX_TOKENS = 4096
|
|||
|
||||
parser = argparse.ArgumentParser(description='NanoChat Web Server')
|
||||
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
|
||||
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ LLM because it has to learn how every token (a little semantic chunk/atom)
|
|||
maps to the sequence of individual characters that make it up. Larger models
|
||||
learn this eventually on their own, but if we want this capability to exist
|
||||
in smaller models, we have to actively encourage it by over-representing it
|
||||
in the training data. Midtraining is a good place to do this.
|
||||
in the training data. SFT is a good place to do this.
|
||||
|
||||
To preview a few example conversations, run:
|
||||
python -m tasks.spellingbee
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user