Mirror of https://github.com/karpathy/nanochat.git

Merge karpathy/cpu-mps-dev, adding the ability to run on CPU, on MPS, or on CUDA, with autodetect. Gnarly PR, nonzero chance I broke something.

add cpu|mps support

Commit: 33e8a27f91
@ -95,7 +95,7 @@ And a bit more about computing environments that will run nanochat:
## Running on CPU / MPS

If you'd like to tinker with nanochat on your Macbook or a CPU machine, there is a work in progress [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) up here. If you're on Macbook, use `--device_type=mps` when running `base_train.py`. See the PR and its diff for more. You're not going to get too far without GPU nodes, but at least you'll be able to run the code and maybe train a very tiny LLM with some patience.

nanochat can be run on CPU or on MPS (if you're on a Macbook), and will automatically try to detect the best device to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), refer to the [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, run for fewer iterations, etc. This functionality is new, slightly gnarly (it touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.

## Customization
dev/runcpu.sh (new file, 84 lines)
@ -0,0 +1,84 @@
#!/bin/bash

# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/runcpu.sh

# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as an educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/

# all the setup stuff
export OMP_NUM_THREADS=1
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
    WANDB_RUN=dummy
fi
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
    curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
    unzip -q eval_bundle.zip
    rm eval_bundle.zip
    mv eval_bundle $NANOCHAT_BASE_DIR
fi

# wipe the report
python -m nanochat.report reset

# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max_chars=1000000000
python -m scripts.tok_eval

# train a very small 4-layer model on the CPU
# each optimization step processes a single sequence of 1024 tokens
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
    --depth=4 \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --total_batch_size=1024 \
    --eval_every=50 \
    --eval_tokens=4096 \
    --core_metric_every=50 \
    --core_metric_max_per_task=12 \
    --sample_every=50 \
    --num_iterations=50
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
python -m scripts.base_eval --max-per-task=5

# midtraining
python -m scripts.mid_train \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --eval_every=50 \
    --eval_tokens=4096 \
    --total_batch_size=1024 \
    --num_iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20

# SFT
python -m scripts.chat_sft \
    --device_batch_size=1 \
    --target_examples_per_step=4 \
    --num_iterations=100 \
    --eval_steps=4 \
    --eval_metrics_max_problems=16

# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"

# Chat Web
# python -m scripts.chat_web

python -m nanochat.report generate
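Before launching the script it can be handy to confirm what the autodetection below will find on your machine. A quick check along these lines (illustrative only, not part of the commit):

```python
# Quick sanity check of device availability (illustrative, not part of this commit).
import torch

print("cuda available:", torch.cuda.is_available())
print("mps available: ", torch.backends.mps.is_available())
```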
@ -89,32 +89,46 @@ def get_dist_info():
    else:
        return False, 0, 0, 1

def compute_init():

def autodetect_device_type():
    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"
    print0(f"Autodetected device type: {device_type}")
    return device_type

def compute_init(device_type="cuda"): # cuda|cpu|mps
    """Basic initialization that we keep doing over and over, so make common."""

    # CUDA is currently required
    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
    if device_type == "cuda":
        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
    if device_type == "mps":
        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

    # Reproducibility
    torch.manual_seed(42)
    if device_type == "cuda":
        torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # Precision
    if device_type == "cuda":
        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional
    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp:
    if ddp and device_type == "cuda":
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device("cuda")
        device = torch.device(device_type) # mps|cpu

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")
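For orientation, this is roughly how the new pair of functions is used by the training and eval scripts further down in this diff (a minimal sketch, reusing the names above):

```python
# Minimal usage sketch of the new helpers (mirrors the scripts below in this diff).
from nanochat.common import autodetect_device_type, compute_init

device_type = autodetect_device_type()  # prefers CUDA, then MPS, then CPU
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
print(f"running on {device} (ddp={ddp}, world_size={ddp_world_size})")
```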
@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer

def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
    """Stream pretraining text from parquet files, tokenize, yield training batches."""
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@ -43,6 +43,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        # Reshape to 2D and move to GPU async
        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
        yield inputs, targets
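With the new `device` parameter, callers simply forward whatever device `compute_init` returned; a small sketch of the intended call (matching how base_train.py and base_loss.py use it later in this diff):

```python
# Sketch of the new call signature (matches base_train.py / base_loss.py below).
from nanochat.common import autodetect_device_type, compute_init
from nanochat.dataloader import tokenizing_distributed_data_loader

_, _, _, _, device = compute_init(autodetect_device_type())
loader = tokenizing_distributed_data_loader(1, 1024, split="train", device=device)
x, y = next(loader)  # int32 inputs and int64 targets, already on `device`
```

Note that `non_blocking=True` is effectively a no-op when the target device is the CPU, so the same transfer code path works on cuda, mps, and cpu alike.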
@ -146,12 +146,11 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
    with caution.
    """

    if maximum_memory_bytes is not None:
        if platform.uname().system != "Darwin":
            # These resource limit calls seem to fail on macOS (Darwin), skip?
            import resource
            resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
            resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
        if not platform.uname().system == "Darwin":
            resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))

    faulthandler.disable()
@ -225,6 +224,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
    rmtree = shutil.rmtree
    rmdir = os.rmdir
    chdir = os.chdir
    unlink = os.unlink

    # Disable functionalities that can make destructive changes to the test.
    reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
@ -282,6 +282,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
        shutil.rmtree = rmtree
        os.rmdir = rmdir
        os.chdir = chdir
        os.unlink = unlink


def execute_code(
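The sandbox change above boils down to skipping the `RLIMIT_*` calls on macOS rather than letting them fail. An isolated illustration of that guard, with a hypothetical helper name (`limit_memory` is not in the commit):

```python
# Illustration of the macOS guard above (assumed semantics: skip RLIMIT_* on Darwin).
import platform

def limit_memory(maximum_memory_bytes: int) -> None:
    if platform.uname().system != "Darwin":
        import resource  # Unix-only module, imported lazily as in the original code
        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
```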
@ -169,8 +169,6 @@ class GPT(nn.Module):
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)
        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
        self.transformer.wte.to(dtype=torch.bfloat16)

    def init_weights(self):
        self.apply(self._init_weights)
@ -184,6 +182,9 @@ class GPT(nn.Module):
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin
        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
        if self.transformer.wte.weight.device.type == "cuda":
            self.transformer.wte.to(dtype=torch.bfloat16)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
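The embedding cast now happens only when the weights actually live on CUDA; on cpu/mps they stay in fp32. A generic, hypothetical helper showing the same condition (not part of the commit):

```python
# Illustrative helper (not in the commit): cast a module to bf16 only if it lives on CUDA,
# the same condition the diff above applies to the token embedding table.
import torch
import torch.nn as nn

def maybe_cast_bf16(module: nn.Module) -> None:
    if next(module.parameters()).device.type == "cuda":
        module.to(dtype=torch.bfloat16)  # saves memory on GPU
    # on cpu/mps the weights stay in fp32
```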
@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
        loss2d = model(x, y, loss_reduction='none') # (B, T)
        loss2d = loss2d.view(-1) # flatten
        y = y.view(-1) # flatten
        if (y < 0).any():
        if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
            # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
            # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
            valid = y >= 0
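A standalone illustration of the MPS workaround above: per the comment in the diff, the int64 comparison kernel is missing on MPS, so the targets are cast to int32 just for the sign check.

```python
# Same result as (y < 0) for ignore_index-style targets, whose values fit in int32.
import torch

y = torch.tensor([3, -1, 7], dtype=torch.int64)
mask = (y.int() < 0)
print(mask)  # tensor([False,  True, False])
```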
@ -283,6 +283,10 @@ class Report:
            # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
            bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
            bloat_data = bloat_data.group(1) if bloat_data else ""
        else:
            start_time = None # will cause us to not write the total wall clock time
            bloat_data = "[bloat data missing]"
            print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
        # process all the individual sections
        for file_name in EXPECTED_FILES:
            section_file = os.path.join(report_dir, file_name)
@ -11,6 +11,7 @@ dependencies = [
    "numpy==1.26.4",
    "psutil>=7.1.0",
    "regex>=2025.9.1",
    "setuptools>=80.9.0",
    "tiktoken>=0.11.0",
    "tokenizers>=0.22.0",
    "torch>=2.8.0",
@ -22,17 +23,6 @@ dependencies = [
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"

# target torch to cuda 12.8
[tool.uv.sources]
torch = [
    { index = "pytorch-cu128" },
]

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.maturin]
module-name = "rustbpe"
bindings = "pyo3"
@ -53,3 +43,20 @@ testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]

# target torch to cuda 12.8
[tool.uv.sources]
torch = [
    { index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
    { index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
@ -1,3 +1,5 @@
#!/bin/bash

# The $1000 tier of nanochat
# Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
# A bit sparser on comments, see speedrun.sh for more detail
@ -15,11 +15,12 @@ import time
import json
import random
import yaml
from contextlib import nullcontext

import pandas as pd
import torch

from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
@ -118,16 +119,21 @@ def load_hf_model(hf_path: str, device):
# -----------------------------------------------------------------------------
def main():
    assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
    parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
    args = parser.parse_args()

    # distributed / precision setup
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
    device_type = autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

    # Load model and tokenizer from command line or from file system
    if len(sys.argv) >= 2:
    if args.hf_path is not None:
        # atm assume that if a path is given, it's a huggingface model path
        hf_path = sys.argv[1]
        hf_path = args.hf_path
        print0(f"Loading huggingface model from: {hf_path}")
        model, tokenizer = load_hf_model(hf_path, device)
        model_name = hf_path # just for logging
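The `autocast` line above is the precision pattern this PR applies across every script: bf16 autocast on CUDA, plain default precision via `nullcontext` on cpu/mps. A self-contained sketch of the same idiom (the `make_autocast` helper name is illustrative, not in the commit):

```python
# Sketch of the precision pattern used throughout this PR.
from contextlib import nullcontext
import torch

def make_autocast(device_type: str, dtype=torch.bfloat16):
    if device_type == "cuda":
        return torch.amp.autocast(device_type="cuda", dtype=dtype)
    return nullcontext()  # cpu/mps run in default precision
```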
@ -140,7 +146,7 @@ def main():
    # Evaluate the model
    with autocast_ctx:
        out = evaluate_model(model, tokenizer, device)
        out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)

    # Write out the results to a csv file
    core_metric = None
@ -7,9 +7,10 @@ Example run as:
    torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
"""
import os
from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.tokenizer import get_token_bytes
from nanochat.loss_eval import evaluate_bpb
@ -20,15 +21,15 @@ device_batch_size = 32
split_tokens = 20*524288 # number of tokens to evaluate per split
model_tag = None # optional model tag for the output directory name
model_step = None # optional model step for the output directory name
device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file

# Load the base model and the tokenizer
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really

# Set up the precision we'll run with
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
@ -37,7 +38,7 @@ steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
    loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
    loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
    with autocast_ctx:
        bpb = evaluate_bpb(model, loader, steps, token_bytes)
    print0(f"{split_name} bpb: {bpb:.4f}")
@ -6,17 +6,22 @@ python base_train.py
or distributed as:

    torchrun --nproc_per_node=8 base_train.py

If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:

    python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
"""

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext

import wandb
import torch

from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@ -27,6 +32,8 @@ print_banner()
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
@ -45,7 +52,7 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
# Output
@ -57,9 +64,12 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
# -----------------------------------------------------------------------------

# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
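The `synchronize` and `get_max_memory` lambdas above let the rest of the training loop stay device-agnostic: on cpu/mps they degrade to no-ops. A self-contained sketch of how they are used around a timed step further down in this diff (`device_type` is hard-coded here purely for illustration):

```python
# Device-agnostic timing/memory helpers, as used around a training step below.
import time
import torch

device_type = "cpu"  # the real scripts get this from autodetect_device_type()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

synchronize()
t0 = time.time()
# ... forward/backward micro-steps would run here ...
synchronize()
print(f"step time: {time.time() - t0:.3f}s, peak mem: {get_max_memory() / 1024 / 1024:.2f} MiB")
```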
@ -96,7 +106,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
with torch.device("meta"):
    model_config = GPTConfig(**model_config_kwargs)
    model = GPT(model_config)
model.to_empty(device="cuda")
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
@ -133,8 +143,8 @@ adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------
@ -193,7 +203,8 @@ for step in range(num_iterations + 1):
    # once in a while: estimate the CORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    if last_step or (step > 0 and step % core_metric_every == 0):
    results = {}
    if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
        model.eval()
        with autocast_ctx:
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
@ -252,7 +263,7 @@ for step in range(num_iterations + 1):
    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    torch.cuda.synchronize()
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
@ -275,7 +286,7 @@ for step in range(num_iterations + 1):
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------
@ -304,7 +315,7 @@ for step in range(num_iterations + 1):
    })

# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
@ -326,11 +337,11 @@ get_report().log(section="Base model training", data=[
    { # stats about training outcomes
        "Minimum validation bpb": min_val_bpb,
        "Final validation bpb": val_bpb,
        "CORE metric estimate": results["core_metric"],
        "CORE metric estimate": results.get("core_metric", None),
        "MFU %": f"{mfu:.2f}%",
        "Total training flops": f"{flops_so_far:e}",
        "Total training time": f"{total_training_time/60:.2f}m",
        "Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
    }
])
@ -6,7 +6,8 @@ python -m scripts.chat_cli -i mid
"""
import argparse
import torch
from nanochat.common import compute_init
from nanochat.common import compute_init, autodetect_device_type
from contextlib import nullcontext
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
@ -17,11 +18,16 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
args = parser.parse_args()

# Init the model and tokenizer
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)

# Special tokens for the chat state machine
@ -10,11 +10,12 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
import argparse
from functools import partial
from contextlib import nullcontext

import torch
import torch.distributed as dist

from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@ -191,11 +192,13 @@ if __name__ == "__main__":
    parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
    parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
    parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
    parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
    args = parser.parse_args()

    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
    device_type = autodetect_device_type() if args.device_type == "" else args.device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

    model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
    engine = Engine(model, tokenizer)
@ -15,8 +15,9 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext

from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.engine import Engine
@ -36,11 +37,12 @@ source = "mid" # base|mid , which checkpoint to load the model from (base model
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
# compute/precision
device_type = "" # cuda|cpu|mps (empty => autodetect)
dtype = "bfloat16"
device_batch_size = 4 # max to avoid OOM
# optimization
num_epochs = 1
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
target_examples_per_step = 32
unembedding_lr = 0.004
embedding_lr = 0.2
@ -51,6 +53,7 @@ init_lr_frac = 0.02
eval_every = 100
eval_steps = 100
eval_metrics_every = 200
eval_metrics_max_problems = 1024
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
@ -58,10 +61,11 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
# -----------------------------------------------------------------------------

# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@ -128,10 +132,10 @@ assert target_examples_per_step % examples_per_step == 0, "Target examples per s
grad_accum_steps = target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")

if num_iterations == -1:
    # derive num_iterations from num_epochs and the size of the dataset
    assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
    num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
if max_iterations >= 0 and num_iterations > max_iterations:
    print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
    num_iterations = max_iterations
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
@ -191,8 +195,8 @@ for step in range(num_iterations):
        metrics = {}
        with torch.no_grad(), autocast_ctx:
            # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
        metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
        print0(f"Step {step:05d} | {metrics_str}")
        wandb_run.log({
@ -44,8 +44,8 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass

from nanochat.common import compute_init
from contextlib import nullcontext
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@ -69,6 +69,8 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
@ -80,7 +82,9 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)

ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16

@dataclass
class Worker:
@ -95,21 +99,33 @@ class WorkerPool:
    """Pool of workers, each with a model replica on a different GPU."""

    def __init__(self, num_gpus: Optional[int] = None):
        self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
        if num_gpus is None:
            if device_type == "cuda":
                num_gpus = torch.cuda.device_count()
            else:
                num_gpus = 1 # e.g. cpu|mps
        self.num_gpus = num_gpus
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()

    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
        """Load model on each GPU."""
        print(f"Initializing worker pool with {self.num_gpus} GPUs...")
        if self.num_gpus > 1:
            assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."

        for gpu_id in range(self.num_gpus):

            if device_type == "cuda":
                device = torch.device(f"cuda:{gpu_id}")
                print(f"Loading model on GPU {gpu_id}...")
            else:
                device = torch.device(device_type) # e.g. cpu|mps
                print(f"Loading model on {device_type}...")

            model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
            engine = Engine(model, tokenizer)
            autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

            worker = Worker(
                gpu_id=gpu_id,
@ -15,8 +15,8 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch

from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@ -31,9 +31,11 @@ from tasks.customjson import CustomJSON
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
max_seq_len = 2048
device_batch_size = 32
unembedding_lr = 0.004
@ -41,7 +43,7 @@ embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
eval_every = 150
eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
@ -51,10 +53,12 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
# -----------------------------------------------------------------------------

# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@ -117,6 +121,7 @@ def mid_data_generator(split):
    token_buffer = deque()
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
    cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
    it = 0 # iteration counter
    while True:
        # Accumulate enough tokens for one iteration before yielding
        while len(token_buffer) < needed_tokens:
@ -128,6 +133,10 @@ def mid_data_generator(split):
                cursor -= dataset_size # wrap around for another epoch
                if split == "train":
                    last_step = True # toggle last_step to True, which will terminate the training loop
        # Stopping condition to respect num_iterations, if given
        it += 1
        if num_iterations > 0 and it >= num_iterations:
            last_step = True # toggle last_step to True, which will terminate the training loop
        # Build up inputs/targets and yield
        for i in range(needed_tokens):
            scratch[i] = token_buffer.popleft()
@ -136,6 +145,9 @@ def mid_data_generator(split):
        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
        if split == "train":
            if num_iterations > 0:
                approx_progress = it / num_iterations # calculate progress from the max number of iterations
            else:
                approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
        yield inputs, targets
@ -172,7 +184,7 @@ while True:
    last_step = bool(last_step_tensor.item())

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or step % eval_every == 0:
    if eval_every > 0 and (last_step or step % eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
@ -219,7 +231,7 @@ while True:
    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    torch.cuda.synchronize()
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
@ -240,7 +252,7 @@ while True:
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------
@ -272,7 +284,7 @@ while True:
    })

# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
uv.lock (1 line changed)
@ -2002,3 +2002,4 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
{ url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
]