Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)
Support training with CPU only. Coded with the help of the Devin DeepWiki planner and GPT-5 Codex execution in VS Code agent mode. https://deepwiki.com/search/suggest-how-to-modify-the-code_80cebbfc-0ad0-4b92-addd-2b4210fa9f04
This commit is contained in: parent e27d2da3d9, commit de6eac26be
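In rough terms, the change threads a device-aware setup through every entry point: compute_init falls back to CPU when CUDA is missing, and a new resolve_autocast_dtype helper picks a safe autocast dtype. A minimal sketch of the intended call pattern on a CUDA-less machine, using only functions touched in this diff (not an exact excerpt from the repo):

import torch
from nanochat.common import compute_init, resolve_autocast_dtype

ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()  # falls back to CPU when CUDA is absent
autocast_dtype = resolve_autocast_dtype(device.type)                    # bfloat16 on CUDA, float32 on CPU
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=autocast_dtype)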
@@ -5,6 +5,8 @@ Common utilities for nanochat.
import os
import re
import logging
from typing import Optional, Union

import torch
import torch.distributed as dist


@@ -45,6 +47,8 @@ def setup_default_logging():
setup_default_logging()
logger = logging.getLogger(__name__)

_DDP_STATE: Optional[bool] = None

def get_base_dir():
    # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
    if os.environ.get("NANOCHAT_BASE_DIR"):

@@ -77,6 +81,8 @@ def print_banner():

def is_ddp():
    # TODO is there a proper way
    if _DDP_STATE is not None:
        return _DDP_STATE
    return int(os.environ.get('RANK', -1)) != -1

def get_dist_info():

@@ -92,12 +98,13 @@ def get_dist_info():
def compute_init():
    """Basic initialization that we keep doing over and over, so make common."""

    # CUDA is currently required
    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
    global _DDP_STATE

    # Reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)
    # torch.backends.cudnn.deterministic = True

@@ -108,13 +115,23 @@ def compute_init():

    # Distributed setup: Distributed Data Parallel (DDP), optional
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    cuda_available = torch.cuda.is_available()

    if ddp and not cuda_available:
        logger.warning("CUDA not available; disabling distributed setup and running on CPU.")
        ddp = False
        ddp_rank = ddp_local_rank = 0
        ddp_world_size = 1

    if ddp:
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device("cuda")
        device = torch.device("cuda") if cuda_available else torch.device("cpu")

    _DDP_STATE = ddp

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

@@ -123,7 +140,7 @@ def compute_init():

def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit"""
    if is_ddp():
    if is_ddp() and dist.is_initialized():
        dist.destroy_process_group()

class DummyWandb:

@@ -134,3 +151,37 @@ class DummyWandb:
        pass
    def finish(self):
        pass


def resolve_autocast_dtype(
    device_type: str,
    requested_dtype: Optional[Union[str, torch.dtype]] = None
) -> torch.dtype:
    """Return a safe autocast dtype for the given device."""

    if isinstance(requested_dtype, torch.dtype):
        dtype = requested_dtype
    elif isinstance(requested_dtype, str):
        key = requested_dtype.lower()
        if key in {"bfloat16", "bf16"}:
            dtype = torch.bfloat16
        elif key in {"float16", "fp16", "half"}:
            dtype = torch.float16
        elif key in {"float32", "fp32"}:
            dtype = torch.float32
        else:
            raise ValueError(f"Unsupported dtype string: {requested_dtype}")
    elif requested_dtype is None:
        dtype = torch.bfloat16 if device_type == "cuda" else torch.float32
    else:
        raise TypeError(f"Unsupported dtype specifier type: {type(requested_dtype)!r}")

    if device_type != "cuda" and dtype != torch.float32:
        logger.warning(
            "Falling back to float32 autocast on %s (requested %s unsupported)",
            device_type,
            dtype,
        )
        dtype = torch.float32

    return dtype
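The fallback rules of resolve_autocast_dtype above can be summarized with a few calls (expected return values derived from the function body, not additional repo code):

resolve_autocast_dtype("cuda")                  # torch.bfloat16 (default on CUDA)
resolve_autocast_dtype("cpu")                   # torch.float32 (default on CPU)
resolve_autocast_dtype("cuda", "fp16")          # torch.float16
resolve_autocast_dtype("cpu", "bfloat16")       # torch.float32, with a fallback warning
resolve_autocast_dtype("cpu", torch.float16)    # torch.float32, with a fallback warning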
@@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer

def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
def tokenizing_distributed_data_loader(B, T, split, device=None, tokenizer_threads=4, tokenizer_batch_size=128):
    """Stream pretraining text from parquet files, tokenize, yield training batches."""
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

@@ -16,7 +16,14 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
    bos_token = tokenizer.get_bos_token_id()
    # scratch buffer holds the tokens for one iteration
    token_buffer = deque() # we stream tokens on the right and pop from the left
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        device = torch.device(device)

    non_blocking = device.type == "cuda"
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=non_blocking)

    # infinite iterator over document batches
    def document_batches():

@@ -29,6 +36,7 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
    batches = document_batches()

    batch_index = 0

    while True:
        # Accumulate enough tokens for one iteration before yielding.
        while len(token_buffer) < needed_tokens:

@@ -43,7 +51,7 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
        # Create the inputs/targets as 1D tensors
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        # Reshape to 2D and move to GPU async
        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
        # Reshape to 2D and move to target device
        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=non_blocking)
        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=non_blocking)
        yield inputs, targets
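The loader now takes the target device and only pins memory and copies asynchronously when that device is CUDA. A hedged usage sketch (parameter values are illustrative, and it assumes the pretraining parquet shards are already on disk):

# CPU run: tensors stay on the CPU, pin_memory and non_blocking are disabled
loader = tokenizing_distributed_data_loader(B=4, T=256, split="val", device="cpu")
x, y = next(loader)
assert x.device.type == "cpu" and y.device.type == "cpu"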
@@ -169,8 +169,9 @@ class GPT(nn.Module):
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)
        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
        self.transformer.wte.to(dtype=torch.bfloat16)
        # Cast embeddings to bf16 only when the runtime supports it well (e.g. CUDA)
        if torch.cuda.is_available():
            self.transformer.wte.to(dtype=torch.bfloat16)

    def init_weights(self):
        self.apply(self._init_weights)

@@ -210,7 +211,8 @@ class GPT(nn.Module):
        # calculate the rotation frequencies at each (time, channel) pair
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
        rotary_dtype = self.transformer.wte.weight.dtype
        cos, sin = cos.to(dtype=rotary_dtype), sin.to(dtype=rotary_dtype)
        cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
        return cos, sin

@@ -262,7 +264,8 @@ class GPT(nn.Module):
        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
        expected_dtype = self.transformer.wte.weight.dtype
        assert self.cos.dtype == expected_dtype, "Rotary embeddings dtype mismatch"
        # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
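The rotary cache now inherits its dtype from the token embedding instead of being hard-coded to bfloat16, so it stays float32 on CPU. A standalone sketch of that rule in plain PyTorch (not nanochat code, shapes are arbitrary):

import torch

# the embedding dtype is bf16 only when CUDA is available, float32 otherwise
emb_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
freqs = torch.outer(torch.arange(8, dtype=torch.float32), torch.linspace(1.0, 0.1, 4))
cos, sin = freqs.cos().to(emb_dtype), freqs.sin().to(emb_dtype)
assert cos.dtype == emb_dtype  # the cache follows the embedding, as the new assert checks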
@@ -19,7 +19,7 @@ import yaml
import pandas as pd
import torch

from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, resolve_autocast_dtype
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task

@@ -122,7 +122,9 @@ def main():

    # distributed / precision setup
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
    device_type = device.type
    autocast_dtype = resolve_autocast_dtype(device_type)
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)

    # Load model and tokenizer from command line or from file system
    if len(sys.argv) >= 2:
@@ -9,7 +9,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
import os
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup
from nanochat.common import compute_init, print0, compute_cleanup, resolve_autocast_dtype
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.tokenizer import get_token_bytes
from nanochat.loss_eval import evaluate_bpb

@@ -28,7 +28,9 @@ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=mode
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really

# Set up the precision we'll run with
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = device.type
autocast_dtype = resolve_autocast_dtype(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)

# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size

@@ -37,7 +39,7 @@ steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
    loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
    loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
    with autocast_ctx:
        bpb = evaluate_bpb(model, loader, steps, token_bytes)
    print0(f"{split_name} bpb: {bpb:.4f}")
@@ -2,10 +2,10 @@
Train model. Run as:

python base_train.py

or distributed as:

torchrun --nproc_per_node=8 base_train.py
"""

import os

@@ -16,7 +16,7 @@ import torch

from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, resolve_autocast_dtype
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb

@@ -58,8 +58,11 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin

# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = device.type
is_cuda = device_type == "cuda"
autocast_dtype = resolve_autocast_dtype(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process

@@ -96,7 +99,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
with torch.device("meta"):
    model_config = GPTConfig(**model_config_kwargs)
    model = GPT(model_config)
model.to_empty(device="cuda")
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through

@@ -133,8 +136,8 @@ adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------

@@ -165,16 +168,22 @@ def get_muon_momentum(it):
# -----------------------------------------------------------------------------
# Training loop
min_val_bpb = float("inf")
last_val_bpb = float("nan")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
mfu = None
core_results = {"core_metric": float("nan"), "centered_results": {}}
eval_enabled = eval_every > 0 and eval_tokens > 0
core_metric_enabled = core_metric_every > 0 and core_metric_max_per_task > 0
sample_enabled = sample_every > 0
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
    last_step = step == num_iterations
    flops_so_far = num_flops_per_token * total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or step % eval_every == 0:
    if eval_enabled and (last_step or step % eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)

@@ -183,6 +192,7 @@ for step in range(num_iterations + 1):
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        last_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,

@@ -193,22 +203,22 @@ for step in range(num_iterations + 1):

    # once in a while: estimate the CORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    if last_step or (step > 0 and step % core_metric_every == 0):
    if core_metric_enabled and (last_step or (step > 0 and step % core_metric_every == 0)):
        model.eval()
        with autocast_ctx:
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
            core_results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {core_results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
            "core_metric": core_results["core_metric"],
            "centered_results": core_results["centered_results"],
        })
        model.train()

    # once in a while: sample from the model (only on master process)
    # use the original uncompiled model because the inputs keep changing shape
    if master_process and (last_step or (step > 0 and step % sample_every == 0)):
    if sample_enabled and master_process and (last_step or (step > 0 and step % sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",

@@ -252,7 +262,8 @@ for step in range(num_iterations + 1):
    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    torch.cuda.synchronize()
    if is_cuda:
        torch.cuda.synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:

@@ -260,7 +271,7 @@ for step in range(num_iterations + 1):
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
        x, y = next(train_loader) # prefetch the next batch while the accelerator is busy with forward/backward
    # gradient clipping (TODO possibly expertiment with)
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)

@@ -275,7 +286,8 @@ for step in range(num_iterations + 1):
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    if is_cuda:
        torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

@@ -286,13 +298,14 @@ for step in range(num_iterations + 1):
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(world_tokens_per_fwdbwd / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
    promised_flops_per_sec = 989e12 * ddp_world_size if is_cuda else None # bfloat16 H100 baseline
    mfu = 100 * flops_per_sec / promised_flops_per_sec if promised_flops_per_sec else None
    mfu_display = f"{mfu:.2f}" if mfu is not None else "N/A"
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu_display} | total time: {total_training_time/60:.2f}m")
    if step % 100 == 0:
        wandb_run.log({
        log_payload = {
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,

@@ -300,11 +313,17 @@ for step in range(num_iterations + 1):
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        })
        }
        if mfu is not None:
            log_payload["train/mfu"] = mfu
        wandb_run.log(log_payload)

# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
peak_memory_mib = (torch.cuda.max_memory_allocated() / 1024 / 1024) if is_cuda else None
if peak_memory_mib is not None:
    print0(f"Peak memory usage: {peak_memory_mib:.2f}MiB")
else:
    print0("Peak memory usage: N/A (CPU run)")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

@@ -325,12 +344,12 @@ get_report().log(section="Base model training", data=[
    },
    { # stats about training outcomes
        "Minimum validation bpb": min_val_bpb,
        "Final validation bpb": val_bpb,
        "CORE metric estimate": results["core_metric"],
        "MFU %": f"{mfu:.2f}%",
        "Final validation bpb": last_val_bpb,
        "CORE metric estimate": core_results["core_metric"],
        "MFU %": f"{mfu:.2f}%" if mfu is not None else "N/A",
        "Total training flops": f"{flops_so_far:e}",
        "Total training time": f"{total_training_time/60:.2f}m",
        "Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
        "Peak memory usage": f"{peak_memory_mib:.2f}MiB" if peak_memory_mib is not None else "N/A",
    }
])
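On CPU there is no H100 flops baseline and no CUDA memory counter, so MFU and peak memory become optional values that are only reported when defined. A tiny standalone sketch of that guard pattern (numbers are illustrative, not repo values):

is_cuda = False                       # pretend we are on a CPU-only box
flops_per_sec = 1.5e12                # hypothetical measured throughput
promised_flops_per_sec = 989e12 if is_cuda else None
mfu = 100 * flops_per_sec / promised_flops_per_sec if promised_flops_per_sec else None
print(f"mfu: {mfu:.2f}" if mfu is not None else "mfu: N/A")   # -> mfu: N/A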
@@ -6,7 +6,7 @@ python -m scripts.chat_cli -i mid
"""
import argparse
import torch
from nanochat.common import compute_init
from nanochat.common import compute_init, resolve_autocast_dtype
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model


@@ -21,7 +21,9 @@ args = parser.parse_args()

# Init the model and tokenizer
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = device.type
autocast_dtype = resolve_autocast_dtype(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)

# Special tokens for the chat state machine
@@ -15,6 +15,7 @@ import torch
import torch.distributed as dist

from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
from nanochat.common import resolve_autocast_dtype
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine


@@ -194,8 +195,9 @@ if __name__ == "__main__":
    args = parser.parse_args()

    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
    ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
    device_type = device.type
    autocast_dtype = resolve_autocast_dtype(device_type, args.dtype)
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)

    model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
    engine = Engine(model, tokenizer)
@@ -23,7 +23,7 @@ import wandb
import torch
import torch.distributed as dist

from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, resolve_autocast_dtype
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
from tasks.gsm8k import GSM8K

@@ -55,9 +55,10 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin

# Init compute/precision
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = device.type
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_dtype = resolve_autocast_dtype(device_type, dtype)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@@ -17,7 +17,7 @@ import wandb
import torch
import torch.distributed as dist

from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, resolve_autocast_dtype
from nanochat.checkpoint_manager import load_model
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.engine import Engine

@@ -61,9 +61,10 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi

# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = device.type
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_dtype = resolve_autocast_dtype(device_type, dtype)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@@ -16,7 +16,7 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator

from nanochat.common import compute_init
from nanochat.common import compute_init, resolve_autocast_dtype
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine


@@ -32,7 +32,9 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
args = parser.parse_args()

ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = device.type
autocast_dtype = resolve_autocast_dtype(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)

class ChatMessage(BaseModel):
    role: str
@@ -16,7 +16,7 @@ import time
import wandb
import torch

from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, resolve_autocast_dtype
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb

@@ -51,9 +51,11 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi

# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = device.type
is_cuda = device_type == "cuda"
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_dtype = resolve_autocast_dtype(device_type, dtype)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process

@@ -129,8 +131,8 @@ def mid_data_generator(split):
            scratch[i] = token_buffer.popleft()
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=is_cuda)
        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=is_cuda)
        if split == "train":
            approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
        yield inputs, targets

@@ -156,6 +158,7 @@ min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
mfu = None
step = 0
while True:
    flops_so_far = num_flops_per_token * total_batch_size * step

@@ -214,7 +217,8 @@ while True:
    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    torch.cuda.synchronize()
    if is_cuda:
        torch.cuda.synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:

@@ -222,7 +226,7 @@ while True:
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
        x, y = next(train_loader) # prefetch the next batch while the accelerator is busy with forward/backward
        progress = max(progress, approx_progress) # only increase progress monotonically
    # step the optimizers
    lrm = get_lr_multiplier(progress)

@@ -235,7 +239,8 @@ while True:
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    if is_cuda:
        torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

@@ -249,13 +254,14 @@ while True:
    pct_done = 100 * progress
    tok_per_sec = int(world_tokens_per_fwdbwd / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
    promised_flops_per_sec = 989e12 * ddp_world_size if is_cuda else None # bfloat16 H100 baseline
    mfu = 100 * flops_per_sec / promised_flops_per_sec if promised_flops_per_sec else None
    mfu_display = f"{mfu:.2f}" if mfu is not None else "N/A"
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu_display} | total time: {total_training_time/60:.2f}m")
    if step % 10 == 0:
        wandb_run.log({
        log_payload = {
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,

@@ -263,11 +269,17 @@ while True:
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        })
        }
        if mfu is not None:
            log_payload["train/mfu"] = mfu
        wandb_run.log(log_payload)

# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
peak_memory_mib = (torch.cuda.max_memory_allocated() / 1024 / 1024) if is_cuda else None
if peak_memory_mib is not None:
    print0(f"Peak memory usage: {peak_memory_mib:.2f}MiB")
else:
    print0("Peak memory usage: N/A (CPU run)")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
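Both training scripts wrap their step timing in the same CUDA-only synchronize guard; on CPU the wall-clock time is taken directly. A minimal standalone version of that pattern (the work in the middle is a stand-in, not repo code):

import time
import torch

is_cuda = torch.cuda.is_available()
if is_cuda:
    torch.cuda.synchronize()          # flush pending GPU work before timing
t0 = time.time()
_ = torch.randn(256, 256) @ torch.randn(256, 256)   # stand-in for forward/backward/step
if is_cuda:
    torch.cuda.synchronize()          # make sure the timed work actually finished
dt = time.time() - t0
print(f"dt: {dt * 1000:.2f}ms")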