mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00

commit df600b6ed5 (parent 786119d593)
many small tweaks. base, eval, core work now i think
@@ -93,9 +93,10 @@ def autodetect_device_type():
     # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
     if torch.cuda.is_available():
         device_type = "cuda"
-    if torch.backends.mps.is_available():
+    elif torch.backends.mps.is_available():
         device_type = "mps"
-    device_type = "cpu"
+    else:
+        device_type = "cpu"
     print0(f"Autodetected device type: {device_type}")
     return device_type
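The hunk above fixes a fallthrough: without the elif/else chain, the trailing device_type = "cpu" assignment overwrote whatever was detected before it. A minimal standalone sketch of the corrected selection logic (print0 and the surrounding module are assumed; plain print is used here):

import torch

def autodetect_device_type_sketch():
    # Prefer CUDA, then Apple MPS, and only fall back to CPU when neither is available.
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"  # reached only when no accelerator was found
    print(f"Autodetected device type: {device_type}")
    return device_type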
@@ -283,6 +283,10 @@ class Report:
             # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
             bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
             bloat_data = bloat_data.group(1) if bloat_data else ""
+        else:
+            start_time = None # will cause us to not write the total wall clock time
+            bloat_data = "[bloat data missing]"
+            print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
         # process all the individual sections
         for file_name in EXPECTED_FILES:
             section_file = os.path.join(report_dir, file_name)
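The new else branch gives the summary sane defaults when the header report is missing. A small self-contained sketch of the same capture pattern, assuming a markdown report whose "### Bloat" section ends at the first blank line (the sample content is illustrative):

import re

header_content = """# Report
### Bloat
Characters: 123456
Lines: 7890

### Next section
"""

# re.DOTALL lets '.' match newlines, so the non-greedy group grabs everything
# between the '### Bloat' header and the first blank line that follows it.
match = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
bloat_data = match.group(1) if match else "[bloat data missing]"
print(bloat_data)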
@@ -15,6 +15,7 @@ import time
 import json
 import random
 import yaml
+from contextlib import nullcontext
 
 import pandas as pd
 import torch
@@ -118,18 +119,21 @@ def load_hf_model(hf_path: str, device):
 
 # -----------------------------------------------------------------------------
 def main():
-    assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
+    parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
+    args = parser.parse_args()
 
     # distributed / precision setup
     device_type = autodetect_device_type()
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-    dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
-    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
 
     # Load model and tokenizer from command line or from file system
-    if len(sys.argv) >= 2:
+    if args.hf_path is not None:
         # atm assume that if a path is given, it's a huggingface model path
-        hf_path = sys.argv[1]
+        hf_path = args.hf_path
         print0(f"Loading huggingface model from: {hf_path}")
         model, tokenizer = load_hf_model(hf_path, device)
         model_name = hf_path # just for logging
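Switching base_eval.py from positional sys.argv parsing to argparse turns both options into optional, self-documenting flags. A minimal sketch of the same interface (only the two flags shown in the diff; everything else is omitted):

import argparse

parser = argparse.ArgumentParser(description="Evaluate a base model")
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
args = parser.parse_args()

# argparse converts dashes to underscores in attribute names:
# --hf-path is read as args.hf_path, --max-per-task as args.max_per_task.
if args.hf_path is not None:
    print(f"would load HuggingFace model from: {args.hf_path}")
else:
    print("would load the local checkpoint instead")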
@@ -142,7 +146,7 @@ def main():
 
     # Evaluate the model
     with autocast_ctx:
-        out = evaluate_model(model, tokenizer, device)
+        out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
 
     # Write out the results to a csv file
     core_metric = None
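evaluate_model's internals are not part of this diff, but the new max_per_task argument presumably caps how many examples each task contributes, with -1 meaning "no cap". A hypothetical sketch of that kind of cap (cap_examples and the example list are illustrative, not nanochat's actual code):

def cap_examples(examples, max_per_task=-1):
    # -1 (the CLI default) means "no cap": evaluate every example in the task.
    if max_per_task is None or max_per_task < 0:
        return examples
    return examples[:max_per_task]

task_examples = [f"example {i}" for i in range(1000)]
print(len(cap_examples(task_examples)))      # 1000, cap disabled
print(len(cap_examples(task_examples, 8)))   # 8, quick smoke-test sized run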
@@ -7,6 +7,7 @@ Example run as:
 torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
 """
 import os
+from contextlib import nullcontext
 import torch
 from nanochat.checkpoint_manager import load_model
 from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
@@ -26,10 +27,9 @@ exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from
 # Load the base model and the tokenizer
 device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
 model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
 sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
 
 # Evaluate the loss on each split
 tokens_per_step = device_batch_size * sequence_len * ddp_world_size
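This is the recurring change of the commit: instead of autocasting to float32 on CPU/MPS (a no-op at best, unsupported at worst), autocast is only entered on CUDA and replaced with contextlib.nullcontext() everywhere else. A minimal sketch of the pattern:

from contextlib import nullcontext
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"

# bfloat16 autocast on CUDA; elsewhere a do-nothing context manager, so the
# forward pass simply runs in the model's native (float32) precision.
autocast_ctx = (
    torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
    if device_type == "cuda"
    else nullcontext()
)

model = torch.nn.Linear(16, 4).to(device_type)
x = torch.randn(2, 16, device=device_type)
with autocast_ctx:
    y = model(x)
print(y.dtype)  # torch.bfloat16 under CUDA autocast, torch.float32 otherwise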
@@ -38,7 +38,7 @@ steps = split_tokens // tokens_per_step
 token_bytes = get_token_bytes(device=device)
 bpb_results = {}
 for split_name in ["train", "val"]:
-    loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
+    loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
     with autocast_ctx:
         bpb = evaluate_bpb(model, loader, steps, token_bytes)
     print0(f"{split_name} bpb: {bpb:.4f}")
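The loader now receives the target device explicitly, so batches can be placed on the right device inside the loader rather than by the caller. The real tokenizing_distributed_data_loader is not shown in this diff; the generator below is a hypothetical stand-in that only illustrates the device= plumbing:

import torch

def toy_token_loader(batch_size, seq_len, device="cpu", vocab_size=50304):
    # Yield (inputs, targets) batches already resident on the requested device.
    while True:
        buf = torch.randint(vocab_size, (batch_size, seq_len + 1))
        x = buf[:, :-1].to(device, non_blocking=True)
        y = buf[:, 1:].to(device, non_blocking=True)
        yield x, y

loader = toy_token_loader(batch_size=2, seq_len=8, device="cpu")
x, y = next(loader)
print(x.shape, y.shape, x.device)  # torch.Size([2, 8]) torch.Size([2, 8]) cpu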
@@ -7,13 +7,15 @@ or distributed as:
 
 torchrun --nproc_per_node=8 base_train.py
 
-python -m scripts.base_train --device_type=cpu --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_max_per_task=8 --total_batch_size=512 --num_iterations=500
-If you have a Macbook, you're better off using device_type=mps instead of cpu
+If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
+python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
 """
 
 import os
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
+from contextlib import nullcontext
 
 import wandb
 import torch
@@ -50,7 +52,7 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
 # Evaluation
 eval_every = 250 # every how many steps to evaluate the model for val bpb
 eval_tokens = 20*524288 # number of tokens to evaluate val loss on
-core_metric_every = 2000 # every how many steps to evaluate the core metric
+core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
 core_metric_max_per_task = 500 # examples per task in estimating the core metric
 sample_every = 2000 # every how many steps to sample from the model
 # Output
@@ -65,8 +67,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
 device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
 synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
 
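The unchanged context lines show the device-agnostic shims base_train.py relies on: on CUDA they call the real torch.cuda helpers, elsewhere they degrade to no-ops so the timing and memory-reporting code needs no branches. A small sketch of how such shims get used, assuming the same names:

import time
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

model = torch.nn.Linear(256, 256).to(device_type)
x = torch.randn(64, 256, device=device_type)

synchronize()               # make sure pending GPU work is done before timing
t0 = time.time()
y = model(x)
synchronize()               # and again so the measured interval is accurate
dt = time.time() - t0
print(f"step took {dt*1000:.3f}ms, peak memory {get_max_memory()} bytes")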
@@ -202,7 +203,8 @@ for step in range(num_iterations + 1):
 
     # once in a while: estimate the CORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
-    if last_step or (step > 0 and step % core_metric_every == 0):
+    results = {}
+    if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
         model.eval()
         with autocast_ctx:
             results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
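Two things changed here: results is initialized to an empty dict on every step, and the CORE evaluation is skipped entirely when core_metric_every <= 0. That is what lets the final report use results.get(...) safely (see the last hunk below). A toy sketch of the guard, with a fake evaluate function standing in for evaluate_model:

def fake_evaluate():
    return {"core_metric": 0.123}

num_iterations = 10
core_metric_every = -1  # -1 disables the periodic evaluation entirely

for step in range(num_iterations + 1):
    last_step = step == num_iterations
    results = {}  # always defined, even on steps where evaluation is skipped
    if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
        results = fake_evaluate()

# With evaluation disabled, .get() falls back to None instead of raising KeyError.
print(results.get("core_metric", None))  # None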
@@ -228,7 +230,7 @@ for step in range(num_iterations + 1):
             "My favorite color is",
             "If 5*x + 3 = 13, then x is",
         ]
-        engine = Engine(model, tokenizer)
+        engine = Engine(orig_model, tokenizer)
         for prompt in prompts:
             tokens = tokenizer(prompt, prepend="<|bos|>")
             with autocast_ctx:
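Sampling now goes through orig_model, matching the comment above the CORE hunk: a compiled module keeps recompiling when input shapes change, so generation and evaluation use the uncompiled reference. A minimal sketch of keeping both handles around, assuming torch.compile is used the way the comments suggest (the Linear module here is only illustrative):

import torch

orig_model = torch.nn.Linear(8, 8)   # keep a handle to the eager module
model = torch.compile(orig_model)    # compiled version used for fixed-shape training steps

# Fixed-shape training forward goes through the compiled module...
loss = model(torch.randn(4, 8)).sum()

# ...while variable-length generation/eval uses the original module to avoid
# triggering a recompilation for every new sequence length.
for seq_len in (3, 5, 7):
    out = orig_model(torch.randn(seq_len, 8))
    print(seq_len, out.shape)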
@@ -335,7 +337,7 @@ get_report().log(section="Base model training", data=[
     { # stats about training outcomes
         "Minimum validation bpb": min_val_bpb,
         "Final validation bpb": val_bpb,
-        "CORE metric estimate": results["core_metric"],
+        "CORE metric estimate": results.get("core_metric", None),
         "MFU %": f"{mfu:.2f}%",
         "Total training flops": f"{flops_so_far:e}",
         "Total training time": f"{total_training_time/60:.2f}m",