mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 12:22:18 +00:00)

commit 5bdc99abfb: merge and resolve conflict

dev/runcpu.sh | 84 (new file)
@@ -0,0 +1,84 @@
+#!/bin/bash
+
+# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
+# Run as:
+# bash dev/runcpu.sh
+
+# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
+# Think of this run as an educational/fun demo, not something you should expect to work well.
+# This is also why I hide this script away in dev/
+
+# all the setup stuff
+export OMP_NUM_THREADS=1
+NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
+mkdir -p $NANOCHAT_BASE_DIR
+command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
+[ -d ".venv" ] || uv venv
+uv sync
+source .venv/bin/activate
+if [ -z "$WANDB_RUN" ]; then
+    WANDB_RUN=dummy
+fi
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+source "$HOME/.cargo/env"
+uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
+EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
+if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
+    curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
+    unzip -q eval_bundle.zip
+    rm eval_bundle.zip
+    mv eval_bundle $NANOCHAT_BASE_DIR
+fi
+
+# wipe the report
+python -m nanochat.report reset
+
+# train tokenizer on ~1B characters
+python -m nanochat.dataset -n 4
+python -m scripts.tok_train --max_chars=1000000000
+python -m scripts.tok_eval
+
+# train a very small 4 layer model on the CPU
+# each optimization step processes a single sequence of 1024 tokens
+# we only run 50 steps of optimization (bump this to get better results)
+python -m scripts.base_train \
+    --depth=4 \
+    --max_seq_len=1024 \
+    --device_batch_size=1 \
+    --total_batch_size=1024 \
+    --eval_every=50 \
+    --eval_tokens=4096 \
+    --core_metric_every=50 \
+    --core_metric_max_per_task=12 \
+    --sample_every=50 \
+    --num_iterations=50
+python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
+python -m scripts.base_eval --max-per-task=5
+
+# midtraining
+python -m scripts.mid_train \
+    --max_seq_len=1024 \
+    --device_batch_size=1 \
+    --eval_every=50 \
+    --eval_tokens=4096 \
+    --total_batch_size=1024 \
+    --num_iterations=100
+# eval results will be terrible, this is just to execute the code paths.
+# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
+python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
+
+# SFT
+python -m scripts.chat_sft \
+    --device_batch_size=1 \
+    --target_examples_per_step=4 \
+    --num_iterations=100 \
+    --eval_steps=4 \
+    --eval_metrics_max_problems=16
+
+# Chat CLI
+# python -m scripts.chat_cli -p "Why is the sky blue?"
+
+# Chat Web
+# python -m scripts.chat_web
+
+python -m nanochat.report generate

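A quick sanity check on the arithmetic behind the tiny training run above. This is an illustrative sketch, not code from the repo; it assumes base_train derives gradient accumulation as total_batch_size divided by the tokens processed per forward/backward pass (device_batch_size * max_seq_len * world_size):

    # Hypothetical bookkeeping for the flags used in dev/runcpu.sh above.
    device_batch_size = 1     # sequences per forward/backward pass
    max_seq_len = 1024        # tokens per sequence
    world_size = 1            # single process on CPU/MPS
    total_batch_size = 1024   # tokens per optimization step

    tokens_per_fwdbwd = device_batch_size * max_seq_len * world_size  # 1024
    grad_accum_steps = total_batch_size // tokens_per_fwdbwd          # 1
    print(tokens_per_fwdbwd, grad_accum_steps)  # 1024 1 -> one sequence per optimizer step, no accumulation
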
@@ -89,11 +89,25 @@ def get_dist_info():
     else:
         return False, 0, 0, 1

-def compute_init(device_type="cuda"): # cuda|cpu
+def autodetect_device_type():
+    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
+    if torch.cuda.is_available():
+        device_type = "cuda"
+    elif torch.backends.mps.is_available():
+        device_type = "mps"
+    else:
+        device_type = "cpu"
+    print0(f"Autodetected device type: {device_type}")
+    return device_type
+
+def compute_init(device_type="cuda"): # cuda|cpu|mps
     """Basic initialization that we keep doing over and over, so make common."""

-    # CUDA is currently required
-    # assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
+    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
+    if device_type == "cuda":
+        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+    if device_type == "mps":
+        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

     # Reproducibility
     torch.manual_seed(42)

@@ -101,11 +115,10 @@ def compute_init(device_type="cuda"): # cuda|cpu
         torch.cuda.manual_seed(42)
     # skipping full reproducibility for now, possibly investigate slowdown later
     # torch.use_deterministic_algorithms(True)
-    # torch.backends.cudnn.deterministic = True
-    # torch.backends.cudnn.benchmark = False

     # Precision
-    torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
+    if device_type == "cuda":
+        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls

     # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

@@ -115,7 +128,7 @@ def compute_init(device_type="cuda"): # cuda|cpu
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
     else:
-        device = torch.device(device_type) # cuda|cpu
+        device = torch.device(device_type) # mps|cpu

     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")

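The scripts touched later in this diff all adopt the same calling pattern around these two helpers. A minimal sketch of that pattern (paraphrased from the hunks below, not copied from any single file):

    from contextlib import nullcontext
    import torch
    from nanochat.common import compute_init, autodetect_device_type

    device_type = ""  # cuda|cpu|mps; empty string means autodetect
    device_type = autodetect_device_type() if device_type == "" else device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    # bf16 autocast is only used on CUDA; CPU/MPS fall back to a no-op context
    autocast_ctx = (
        torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
        if device_type == "cuda" else nullcontext()
    )
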
@@ -146,13 +146,12 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
     with caution.
     """

-    if maximum_memory_bytes is not None:
+    if platform.uname().system != "Darwin":
+        # These resource limit calls seem to fail on macOS (Darwin), skip?
         import resource

         resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
         resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
-        if not platform.uname().system == "Darwin":
-            resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
+        resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))

     faulthandler.disable()

@@ -225,6 +224,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
     rmtree = shutil.rmtree
     rmdir = os.rmdir
     chdir = os.chdir
+    unlink = os.unlink

     # Disable functionalities that can make destructive changes to the test.
     reliability_guard(maximum_memory_bytes=maximum_memory_bytes)

@@ -282,6 +282,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
     shutil.rmtree = rmtree
     os.rmdir = rmdir
     os.chdir = chdir
+    os.unlink = unlink


 def execute_code(

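For context, the limits being guarded above are standard Unix rlimits, and the resource module is Unix-only. A standalone sketch of the same skip-on-macOS logic (illustrative only; the function name is hypothetical, and it is meant for a disposable child process, e.g. with the ~1MB cap mentioned in dev/runcpu.sh):

    import platform

    def limit_memory(maximum_memory_bytes: int) -> None:
        """Apply the same rlimits as reliability_guard; call this in a throwaway child process."""
        if platform.uname().system == "Darwin":
            return  # these rlimit calls tend to fail on macOS, so skip there, as the diff above does
        import resource
        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))

    # e.g. limit_memory(1 * 1024 * 1024)  # the ~1MB execution limit used by the CPU demo run
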
@@ -169,8 +169,6 @@ class GPT(nn.Module):
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
-        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
-        self.transformer.wte.to(dtype=torch.bfloat16)

     def init_weights(self):
         self.apply(self._init_weights)

@@ -184,6 +182,9 @@ class GPT(nn.Module):
         head_dim = self.config.n_embd // self.config.n_head
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.cos, self.sin = cos, sin
+        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
+        if self.transformer.wte.weight.device.type == "cuda":
+            self.transformer.wte.to(dtype=torch.bfloat16)

     def _init_weights(self, module):
         if isinstance(module, nn.Linear):

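Rough arithmetic for why the bf16 cast of the token embedding is worth doing on CUDA (hypothetical sizes chosen only for illustration; the real ones depend on the model config):

    # Hypothetical sizes for illustration: vocab_size = 65536, n_embd = 1280.
    vocab_size, n_embd = 65536, 1280
    fp32_bytes = vocab_size * n_embd * 4
    bf16_bytes = vocab_size * n_embd * 2
    print(f"wte fp32: {fp32_bytes / 2**20:.0f} MiB, bf16: {bf16_bytes / 2**20:.0f} MiB, "
          f"saved: {(fp32_bytes - bf16_bytes) / 2**20:.0f} MiB")
    # wte fp32: 320 MiB, bf16: 160 MiB, saved: 160 MiB (plus the matching savings in activations)
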
@@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
         loss2d = model(x, y, loss_reduction='none') # (B, T)
         loss2d = loss2d.view(-1) # flatten
         y = y.view(-1) # flatten
-        if (y < 0).any():
+        if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
             # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
             # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
             valid = y >= 0

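The cast-before-compare workaround is easy to check on any backend; a tiny sketch (token ids are far below the int32 range, so the downcast cannot change the sign of the comparison):

    import torch

    y = torch.tensor([5, -1, 7, -1], dtype=torch.int64)  # -1 plays the role of ignore_index
    assert torch.equal(y < 0, y.int() < 0)  # identical masks; the int32 form also runs on MPS
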
@@ -283,6 +283,10 @@ class Report:
             # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
             bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
             bloat_data = bloat_data.group(1) if bloat_data else ""
+        else:
+            start_time = None # will cause us to not write the total wall clock time
+            bloat_data = "[bloat data missing]"
+            print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
         # process all the individual sections
         for file_name in EXPECTED_FILES:
             section_file = os.path.join(report_dir, file_name)

@@ -11,6 +11,7 @@ dependencies = [
     "numpy==1.26.4",
     "psutil>=7.1.0",
     "regex>=2025.9.1",
+    "setuptools>=80.9.0",
     "tiktoken>=0.11.0",
     "tokenizers>=0.22.0",
     "torch>=2.8.0",

@@ -22,17 +23,6 @@ dependencies = [
 requires = ["maturin>=1.7,<2.0"]
 build-backend = "maturin"

-# target torch to cuda 12.8
-[tool.uv.sources]
-torch = [
-    { index = "pytorch-cu128" },
-]
-
-[[tool.uv.index]]
-name = "pytorch-cu128"
-url = "https://download.pytorch.org/whl/cu128"
-explicit = true
-
 [tool.maturin]
 module-name = "rustbpe"
 bindings = "pyo3"

@@ -53,3 +43,20 @@ testpaths = ["tests"]
 python_files = ["test_*.py"]
 python_classes = ["Test*"]
 python_functions = ["test_*"]
+
+# target torch to cuda 12.8
+[tool.uv.sources]
+torch = [
+    { index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
+    { index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
+]
+
+[[tool.uv.index]]
+name = "pytorch-cpu"
+url = "https://download.pytorch.org/whl/cpu"
+explicit = true
+
+[[tool.uv.index]]
+name = "pytorch-cu128"
+url = "https://download.pytorch.org/whl/cu128"
+explicit = true

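The two index entries are selected via the standard PEP 508 sys_platform marker, which evaluates to the same value Python itself reports. A quick way to see which index applies on the current machine (a sketch, not part of the repo):

    import sys

    # 'linux' -> pytorch-cu128 (CUDA 12.8 wheels); anything else (e.g. 'darwin', 'win32') -> pytorch-cpu
    index = "pytorch-cu128" if sys.platform == "linux" else "pytorch-cpu"
    print(sys.platform, "->", index)
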
@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # The $1000 tier of nanochat
 # Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
 # A bit sparser on comments, see speedrun.sh for more detail

@@ -15,11 +15,12 @@ import time
 import json
 import random
 import yaml
+from contextlib import nullcontext

 import pandas as pd
 import torch

-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import HuggingFaceTokenizer
 from nanochat.checkpoint_manager import load_model
 from nanochat.core_eval import evaluate_task

@@ -118,16 +119,21 @@ def load_hf_model(hf_path: str, device):

 # -----------------------------------------------------------------------------
 def main():
-    assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
+    parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
+    args = parser.parse_args()

     # distributed / precision setup
-    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+    device_type = autodetect_device_type()
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

     # Load model and tokenizer from command line or from file system
-    if len(sys.argv) >= 2:
+    if args.hf_path is not None:
         # atm assume that if a path is given, it's a huggingface model path
-        hf_path = sys.argv[1]
+        hf_path = args.hf_path
         print0(f"Loading huggingface model from: {hf_path}")
         model, tokenizer = load_hf_model(hf_path, device)
         model_name = hf_path # just for logging

@@ -140,7 +146,7 @@ def main():

     # Evaluate the model
     with autocast_ctx:
-        out = evaluate_model(model, tokenizer, device)
+        out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)

     # Write out the results to a csv file
     core_metric = None

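With these flags in place, the small-budget invocation used by dev/runcpu.sh above is simply: python -m scripts.base_eval --max-per-task=5, and a Hugging Face checkpoint is now passed via --hf-path rather than as a positional argument.
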
@@ -7,9 +7,10 @@ Example run as:
 torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
 """
 import os
+from contextlib import nullcontext
 import torch
 from nanochat.checkpoint_manager import load_model
-from nanochat.common import compute_init, print0, compute_cleanup
+from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
 from nanochat.dataloader import tokenizing_distributed_data_loader
 from nanochat.tokenizer import get_token_bytes
 from nanochat.loss_eval import evaluate_bpb

@@ -20,15 +21,15 @@ device_batch_size = 32
 split_tokens = 20*524288 # number of tokens to evaluate per split
 model_tag = None # optional model tag for the output directory name
 model_step = None # optional model step for the output directory name
+device_type = "" # cuda|cpu|mps (empty => autodetect)
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file

 # Load the base model and the tokenizer
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if device_type == "" else device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
 sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
-# Set up the precision we'll run with
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

 # Evaluate the loss on each split
 tokens_per_step = device_batch_size * sequence_len * ddp_world_size

@@ -37,7 +38,7 @@ steps = split_tokens // tokens_per_step
 token_bytes = get_token_bytes(device=device)
 bpb_results = {}
 for split_name in ["train", "val"]:
-    loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
+    loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
     with autocast_ctx:
         bpb = evaluate_bpb(model, loader, steps, token_bytes)
     print0(f"{split_name} bpb: {bpb:.4f}")

@@ -7,19 +7,21 @@ or distributed as:

 torchrun --nproc_per_node=8 base_train.py

-If you just want to see it run on CPU (you won't get far but it should run), try something like:
-python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --device_type=cpu --eval_tokens=512 --total_batch_size=512 --num_iterations=1000
+If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
+python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
 """
 import os
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
+from contextlib import nullcontext

 import wandb
 import torch

 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb

@@ -31,7 +33,7 @@ print_banner()
 # User settings
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
 # Runtime
-device_type = "cuda" # cuda|cpu
+device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
 # Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
 max_seq_len = 2048 # max context length

@@ -50,7 +52,7 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
 # Evaluation
 eval_every = 250 # every how many steps to evaluate the model for val bpb
 eval_tokens = 20*524288 # number of tokens to evaluate val loss on
-core_metric_every = 2000 # every how many steps to evaluate the core metric
+core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
 core_metric_max_per_task = 500 # examples per task in estimating the core metric
 sample_every = 2000 # every how many steps to sample from the model
 # Output

@@ -62,9 +64,10 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
 # -----------------------------------------------------------------------------

 # Compute init
+device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
 synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

@@ -200,7 +203,8 @@ for step in range(num_iterations + 1):

     # once in a while: estimate the CORE metric (all ranks participate)
     # use the original uncompiled model because the inputs keep changing shape
-    if last_step or (step > 0 and step % core_metric_every == 0):
+    results = {}
+    if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
         model.eval()
         with autocast_ctx:
             results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)

@@ -333,7 +337,7 @@ get_report().log(section="Base model training", data=[
     { # stats about training outcomes
         "Minimum validation bpb": min_val_bpb,
         "Final validation bpb": val_bpb,
-        "CORE metric estimate": results["core_metric"],
+        "CORE metric estimate": results.get("core_metric", None),
         "MFU %": f"{mfu:.2f}%",
         "Total training flops": f"{flops_so_far:e}",
         "Total training time": f"{total_training_time/60:.2f}m",

@@ -6,7 +6,8 @@ python -m scripts.chat_cli -i mid
 """
 import argparse
 import torch
-from nanochat.common import compute_init
+from nanochat.common import compute_init, autodetect_device_type
+from contextlib import nullcontext
 from nanochat.engine import Engine
 from nanochat.checkpoint_manager import load_model

@@ -17,11 +18,16 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
 parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
 parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
 parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
+parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
+parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
 args = parser.parse_args()

 # Init the model and tokenizer
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+device_type = autodetect_device_type() if args.device_type == "" else args.device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)

 # Special tokens for the chat state machine

@@ -10,11 +10,12 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy

 import argparse
 from functools import partial
+from contextlib import nullcontext

 import torch
 import torch.distributed as dist

-from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
+from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine

@@ -191,11 +192,13 @@ if __name__ == "__main__":
     parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
     parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
     parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
+    parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
     args = parser.parse_args()

-    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+    device_type = autodetect_device_type() if args.device_type == "" else args.device_type
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
     ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

     model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
     engine = Engine(model, tokenizer)

@@ -15,8 +15,9 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import wandb
 import torch
 import torch.distributed as dist
+from contextlib import nullcontext

-from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
+from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.engine import Engine

@@ -36,11 +37,12 @@ source = "mid" # base|mid , which checkpoint to load the model from (base model
 model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 # compute/precision
+device_type = "" # cuda|cpu|mps (empty => autodetect)
 dtype = "bfloat16"
 device_batch_size = 4 # max to avoid OOM
 # optimization
 num_epochs = 1
-max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
+num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
 target_examples_per_step = 32
 unembedding_lr = 0.004
 embedding_lr = 0.2

@@ -51,6 +53,7 @@ init_lr_frac = 0.02
 eval_every = 100
 eval_steps = 100
 eval_metrics_every = 200
+eval_metrics_max_problems = 1024
 # now allow CLI to override the settings via the configurator lol
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file

@@ -58,10 +61,11 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
 # -----------------------------------------------------------------------------

 # Compute init
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if device_type == "" else device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0
-dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
+ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process

@@ -128,10 +132,10 @@ assert target_examples_per_step % examples_per_step == 0, "Target examples per s
 grad_accum_steps = target_examples_per_step // examples_per_step
 print0(f"=> Setting grad accum steps: {grad_accum_steps}")

-num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
-if max_iterations >= 0 and num_iterations > max_iterations:
-    print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
-    num_iterations = max_iterations
+if num_iterations == -1:
+    # derive num_iterations from num_epochs and the size of the dataset
+    assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
+    num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
 train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
 build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)

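A worked example of the new default derivation (the dataset size here is a made-up assumption; target_examples_per_step=32 and num_epochs=1 are the script defaults shown above):

    # Assume, purely for illustration, an SFT train split of 20_000 conversations.
    len_train_ds = 20_000
    target_examples_per_step = 32
    num_epochs = 1

    num_iterations = (len_train_ds // target_examples_per_step) * num_epochs
    print(num_iterations)  # 625 optimizer steps for one pass over the data
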
@@ -191,8 +195,8 @@ for step in range(num_iterations):
         metrics = {}
         with torch.no_grad(), autocast_ctx:
             # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
-            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
-            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
+            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
+            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
         metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
         print0(f"Step {step:05d} | {metrics_str}")
         wandb_run.log({

@@ -44,8 +44,8 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
 from typing import List, Optional, AsyncGenerator
 from dataclasses import dataclass
-from nanochat.common import compute_init
+from contextlib import nullcontext
+from nanochat.common import compute_init, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine

@@ -69,6 +69,8 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
 parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
 parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
 parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
+parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
+parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
 parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
 args = parser.parse_args()

@@ -80,7 +82,9 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)

-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if args.device_type == "" else args.device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16

 @dataclass
 class Worker:

@@ -95,21 +99,33 @@ class WorkerPool:
     """Pool of workers, each with a model replica on a different GPU."""

     def __init__(self, num_gpus: Optional[int] = None):
-        self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
+        if num_gpus is None:
+            if device_type == "cuda":
+                num_gpus = torch.cuda.device_count()
+            else:
+                num_gpus = 1 # e.g. cpu|mps
+        self.num_gpus = num_gpus
         self.workers: List[Worker] = []
         self.available_workers: asyncio.Queue = asyncio.Queue()

     async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
         """Load model on each GPU."""
         print(f"Initializing worker pool with {self.num_gpus} GPUs...")
+        if self.num_gpus > 1:
+            assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
+
         for gpu_id in range(self.num_gpus):
-            device = torch.device(f"cuda:{gpu_id}")
-            print(f"Loading model on GPU {gpu_id}...")
+            if device_type == "cuda":
+                device = torch.device(f"cuda:{gpu_id}")
+                print(f"Loading model on GPU {gpu_id}...")
+            else:
+                device = torch.device(device_type) # e.g. cpu|mps
+                print(f"Loading model on {device_type}...")
+
             model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
             engine = Engine(model, tokenizer)
-            autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+            autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

             worker = Worker(
                 gpu_id=gpu_id,

@@ -15,8 +15,8 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
 import wandb
 import torch
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
+from contextlib import nullcontext
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb

@@ -31,9 +31,11 @@ from tasks.customjson import CustomJSON

 # -----------------------------------------------------------------------------
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
+device_type = "" # cuda|cpu|mps (empty => autodetect)
 model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 dtype = "bfloat16"
+num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
 max_seq_len = 2048
 device_batch_size = 32
 unembedding_lr = 0.004

@@ -41,7 +43,7 @@ embedding_lr = 0.2
 matrix_lr = 0.02
 init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
 weight_decay = 0.0
-eval_every = 150
+eval_every = 150 # -1 = disable
 eval_tokens = 20*524288
 total_batch_size = 524288
 dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report

@@ -51,10 +53,12 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
 # -----------------------------------------------------------------------------

 # Compute init
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if device_type == "" else device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0
-dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process

@@ -117,6 +121,7 @@ def mid_data_generator(split):
     token_buffer = deque()
     scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
     cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
+    it = 0 # iteration counter
     while True:
         # Accumulate enough tokens for one iteration before yielding
         while len(token_buffer) < needed_tokens:

@@ -128,6 +133,10 @@ def mid_data_generator(split):
                 cursor -= dataset_size # wrap around for another epoch
                 if split == "train":
                     last_step = True # toggle last_step to True, which will terminate the training loop
+        # Stopping condition to respect num_iterations, if given
+        it += 1
+        if num_iterations > 0 and it >= num_iterations:
+            last_step = True # toggle last_step to True, which will terminate the training loop
         # Build up inputs/targets and yield
         for i in range(needed_tokens):
             scratch[i] = token_buffer.popleft()

@@ -136,7 +145,10 @@ def mid_data_generator(split):
         inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
         targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
         if split == "train":
-            approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
+            if num_iterations > 0:
+                approx_progress = it / num_iterations # calculate progress from the max number of iterations
+            else:
+                approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
         yield inputs, targets

 train_loader = mid_data_generator("train")

@@ -172,7 +184,7 @@ while True:
         last_step = bool(last_step_tensor.item())

     # once in a while: evaluate the val bpb (all ranks participate)
-    if last_step or step % eval_every == 0:
+    if eval_every > 0 and (last_step or step % eval_every == 0):
         model.eval()
         val_loader = build_val_loader()
         eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)

@@ -219,7 +231,7 @@ while True:
     # -------------------------------------------------------------------------
     # single training step
     # evaluate the gradient
-    torch.cuda.synchronize()
+    synchronize()
     t0 = time.time()
     for micro_step in range(grad_accum_steps):
         with autocast_ctx:

@@ -240,7 +252,7 @@ while True:
     for opt in optimizers:
         opt.step()
     model.zero_grad(set_to_none=True)
-    torch.cuda.synchronize()
+    synchronize()
     t1 = time.time()
     dt = t1 - t0
     # -------------------------------------------------------------------------

@@ -272,7 +284,7 @@ while True:
         })

 # print a few more stats
-print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
+print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
 print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

uv.lock | 1

@@ -2002,3 +2002,4 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
     { url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
 ]