Support training with CPU only

Coded with the help of the Devin DeepWiki planner and GPT-5 Codex execution in VS Code agent mode.

https://deepwiki.com/search/suggest-how-to-modify-the-code_80cebbfc-0ad0-4b92-addd-2b4210fa9f04
Luke Stanley 2025-10-14 00:39:03 +00:00
parent e27d2da3d9
commit de6eac26be
12 changed files with 179 additions and 74 deletions
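
The thread running through the files below is that device and DDP setup are now derived from the runtime instead of assuming CUDA. A minimal sketch of that fallback logic (illustrative only, plain PyTorch, not taken verbatim from the diff):

import os
import torch

# DDP is only kept when torchrun has set RANK *and* CUDA is actually present.
ddp = int(os.environ.get("RANK", -1)) != -1
if ddp and not torch.cuda.is_available():
    ddp = False  # degrade to a plain single-process run on CPU

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")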

View File

@ -5,6 +5,8 @@ Common utilities for nanochat.
import os
import re
import logging
from typing import Optional, Union
import torch
import torch.distributed as dist
@ -45,6 +47,8 @@ def setup_default_logging():
setup_default_logging()
logger = logging.getLogger(__name__)
_DDP_STATE: Optional[bool] = None
def get_base_dir():
# co-locate nanochat intermediates with other cached data in ~/.cache (by default)
if os.environ.get("NANOCHAT_BASE_DIR"):
@ -77,6 +81,8 @@ def print_banner():
def is_ddp():
# TODO is there a proper way
if _DDP_STATE is not None:
return _DDP_STATE
return int(os.environ.get('RANK', -1)) != -1
def get_dist_info():
@ -92,12 +98,13 @@ def get_dist_info():
def compute_init():
"""Basic initialization that we keep doing over and over, so make common."""
# CUDA is currently required
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
global _DDP_STATE
# Reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.deterministic = True
@ -108,13 +115,23 @@ def compute_init():
# Distributed setup: Distributed Data Parallel (DDP), optional
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
cuda_available = torch.cuda.is_available()
if ddp and not cuda_available:
logger.warning("CUDA not available; disabling distributed setup and running on CPU.")
ddp = False
ddp_rank = ddp_local_rank = 0
ddp_world_size = 1
if ddp:
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
else:
device = torch.device("cuda")
device = torch.device("cuda") if cuda_available else torch.device("cpu")
_DDP_STATE = ddp
if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}")
@ -123,7 +140,7 @@ def compute_init():
def compute_cleanup():
"""Companion function to compute_init, to clean things up before script exit"""
if is_ddp():
if is_ddp() and dist.is_initialized():
dist.destroy_process_group()
class DummyWandb:
@ -134,3 +151,37 @@ class DummyWandb:
pass
def finish(self):
pass
def resolve_autocast_dtype(
device_type: str,
requested_dtype: Optional[Union[str, torch.dtype]] = None
) -> torch.dtype:
"""Return a safe autocast dtype for the given device."""
if isinstance(requested_dtype, torch.dtype):
dtype = requested_dtype
elif isinstance(requested_dtype, str):
key = requested_dtype.lower()
if key in {"bfloat16", "bf16"}:
dtype = torch.bfloat16
elif key in {"float16", "fp16", "half"}:
dtype = torch.float16
elif key in {"float32", "fp32"}:
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype string: {requested_dtype}")
elif requested_dtype is None:
dtype = torch.bfloat16 if device_type == "cuda" else torch.float32
else:
raise TypeError(f"Unsupported dtype specifier type: {type(requested_dtype)!r}")
if device_type != "cuda" and dtype != torch.float32:
logger.warning(
"Falling back to float32 autocast on %s (requested %s unsupported)",
device_type,
dtype,
)
dtype = torch.float32
return dtype
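
The scripts further down all wire this helper up the same way; a condensed usage sketch using the names introduced in this diff:

import torch
from nanochat.common import compute_init, resolve_autocast_dtype

ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = device.type  # "cuda" or "cpu"
autocast_dtype = resolve_autocast_dtype(device_type)            # defaults: bf16 on CUDA, fp32 on CPU
# autocast_dtype = resolve_autocast_dtype(device_type, "bf16")  # or honour a user-supplied dtype string
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)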

View File

@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
def tokenizing_distributed_data_loader(B, T, split, device=None, tokenizer_threads=4, tokenizer_batch_size=128):
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
assert split in ["train", "val"], "split must be 'train' or 'val'"
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@ -16,7 +16,14 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
elif isinstance(device, str):
device = torch.device(device)
non_blocking = device.type == "cuda"
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=non_blocking)
# infinite iterator over document batches
def document_batches():
@ -29,6 +36,7 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
batches = document_batches()
batch_index = 0
while True:
# Accumulate enough tokens for one iteration before yielding.
while len(token_buffer) < needed_tokens:
@ -43,7 +51,7 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
# Reshape to 2D and move to target device
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=non_blocking)
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=non_blocking)
yield inputs, targets
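
With the new device argument the loader can be pointed at CPU explicitly; a hypothetical call (batch size and sequence length chosen for illustration only):

import torch
from nanochat.dataloader import tokenizing_distributed_data_loader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loader = tokenizing_distributed_data_loader(8, 1024, split="train", device=device)
inputs, targets = next(loader)  # int32 (8, 1024) inputs and int64 (8, 1024) targets, already on `device`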

View File

@ -169,8 +169,9 @@ class GPT(nn.Module):
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
self.transformer.wte.to(dtype=torch.bfloat16)
# Cast embeddings to bf16 only when the runtime supports it well (e.g. CUDA)
if torch.cuda.is_available():
self.transformer.wte.to(dtype=torch.bfloat16)
def init_weights(self):
self.apply(self._init_weights)
@ -210,7 +211,8 @@ class GPT(nn.Module):
# calculate the rotation frequencies at each (time, channel) pair
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
rotary_dtype = self.transformer.wte.weight.dtype
cos, sin = cos.to(dtype=rotary_dtype), sin.to(dtype=rotary_dtype)
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
@ -262,7 +264,8 @@ class GPT(nn.Module):
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
expected_dtype = self.transformer.wte.weight.dtype
assert self.cos.dtype == expected_dtype, "Rotary embeddings dtype mismatch"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length

View File

@ -19,7 +19,7 @@ import yaml
import pandas as pd
import torch
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, resolve_autocast_dtype
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
@ -122,7 +122,9 @@ def main():
# distributed / precision setup
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = device.type
autocast_dtype = resolve_autocast_dtype(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)
# Load model and tokenizer from command line or from file system
if len(sys.argv) >= 2:

View File

@ -9,7 +9,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
import os
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup
from nanochat.common import compute_init, print0, compute_cleanup, resolve_autocast_dtype
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.tokenizer import get_token_bytes
from nanochat.loss_eval import evaluate_bpb
@ -28,7 +28,9 @@ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=mode
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
# Set up the precision we'll run with
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = device.type
autocast_dtype = resolve_autocast_dtype(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)
# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
@ -37,7 +39,7 @@ steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")

View File

@ -2,10 +2,10 @@
Train model. Run as:
python base_train.py
"MFU %": f"{mfu:.2f}%" if mfu is not None else "N/A",
or distributed as:
torchrun --nproc_per_node=8 base_train.py
"Peak memory usage": f"{peak_memory_mib:.2f}MiB" if peak_memory_mib is not None else "N/A",
"""
import os
@ -16,7 +16,7 @@ import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, resolve_autocast_dtype
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@ -58,8 +58,11 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = device.type
is_cuda = device_type == "cuda"
autocast_dtype = resolve_autocast_dtype(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@ -96,7 +99,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device="cuda")
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
@ -133,8 +136,8 @@ adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
@ -165,16 +168,22 @@ def get_muon_momentum(it):
# -----------------------------------------------------------------------------
# Training loop
min_val_bpb = float("inf")
last_val_bpb = float("nan")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
mfu = None
core_results = {"core_metric": float("nan"), "centered_results": {}}
eval_enabled = eval_every > 0 and eval_tokens > 0
core_metric_enabled = core_metric_every > 0 and core_metric_max_per_task > 0
sample_enabled = sample_every > 0
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
last_step = step == num_iterations
flops_so_far = num_flops_per_token * total_batch_size * step
# once in a while: evaluate the val bpb (all ranks participate)
if last_step or step % eval_every == 0:
if eval_enabled and (last_step or step % eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
@ -183,6 +192,7 @@ for step in range(num_iterations + 1):
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
last_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
@ -193,22 +203,22 @@ for step in range(num_iterations + 1):
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
if last_step or (step > 0 and step % core_metric_every == 0):
if core_metric_enabled and (last_step or (step > 0 and step % core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
core_results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {core_results['core_metric']:.4f}")
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"core_metric": results["core_metric"],
"centered_results": results["centered_results"],
"core_metric": core_results["core_metric"],
"centered_results": core_results["centered_results"],
})
model.train()
# once in a while: sample from the model (only on master process)
# use the original uncompiled model because the inputs keep changing shape
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
if sample_enabled and master_process and (last_step or (step > 0 and step % sample_every == 0)):
model.eval()
prompts = [
"The capital of France is",
@ -252,7 +262,8 @@ for step in range(num_iterations + 1):
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
torch.cuda.synchronize()
if is_cuda:
torch.cuda.synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
@ -260,7 +271,7 @@ for step in range(num_iterations + 1):
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
x, y = next(train_loader) # prefetch the next batch while the accelerator is busy with forward/backward
# gradient clipping (TODO possibly experiment with)
if grad_clip > 0.0:
torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
@ -275,7 +286,8 @@ for step in range(num_iterations + 1):
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
torch.cuda.synchronize()
if is_cuda:
torch.cuda.synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
@ -286,13 +298,14 @@ for step in range(num_iterations + 1):
pct_done = 100 * step / num_iterations
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
promised_flops_per_sec = 989e12 * ddp_world_size if is_cuda else None # bfloat16 H100 baseline
mfu = 100 * flops_per_sec / promised_flops_per_sec if promised_flops_per_sec else None
mfu_display = f"{mfu:.2f}" if mfu is not None else "N/A"
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu_display} | total time: {total_training_time/60:.2f}m")
if step % 100 == 0:
wandb_run.log({
log_payload = {
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
@ -300,11 +313,17 @@ for step in range(num_iterations + 1):
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
})
}
if mfu is not None:
log_payload["train/mfu"] = mfu
wandb_run.log(log_payload)
# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
peak_memory_mib = (torch.cuda.max_memory_allocated() / 1024 / 1024) if is_cuda else None
if peak_memory_mib is not None:
print0(f"Peak memory usage: {peak_memory_mib:.2f}MiB")
else:
print0("Peak memory usage: N/A (CPU run)")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
@ -325,12 +344,12 @@ get_report().log(section="Base model training", data=[
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": val_bpb,
"CORE metric estimate": results["core_metric"],
"MFU %": f"{mfu:.2f}%",
"Final validation bpb": last_val_bpb,
"CORE metric estimate": core_results["core_metric"],
"MFU %": f"{mfu:.2f}%" if mfu is not None else "N/A",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
"Peak memory usage": f"{peak_memory_mib:.2f}MiB" if peak_memory_mib is not None else "N/A",
}
])
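
As a sanity check on the MFU figure reported above: the denominator is the promised bf16 throughput of the H100s in the run (989e12 FLOP/s each, no 2:4 sparsity), so the percentage is a straight ratio. A worked example with an assumed achieved throughput:

achieved_flops_per_sec = 3.0e14        # assumption: roughly 300 TFLOP/s actually sustained on one GPU
promised_flops_per_sec = 989e12 * 1    # ddp_world_size = 1, bf16 H100 SXM peak from the code above
mfu = 100 * achieved_flops_per_sec / promised_flops_per_sec
print(f"{mfu:.2f}%")                   # -> 30.33%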

View File

@ -6,7 +6,7 @@ python -m scripts.chat_cli -i mid
"""
import argparse
import torch
from nanochat.common import compute_init
from nanochat.common import compute_init, resolve_autocast_dtype
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
@ -21,7 +21,9 @@ args = parser.parse_args()
# Init the model and tokenizer
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = device.type
autocast_dtype = resolve_autocast_dtype(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
# Special tokens for the chat state machine

View File

@ -15,6 +15,7 @@ import torch
import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
from nanochat.common import resolve_autocast_dtype
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@ -194,8 +195,9 @@ if __name__ == "__main__":
args = parser.parse_args()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
device_type = device.type
autocast_dtype = resolve_autocast_dtype(device_type, args.dtype)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer)

View File

@ -23,7 +23,7 @@ import wandb
import torch
import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, resolve_autocast_dtype
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
from tasks.gsm8k import GSM8K
@ -55,9 +55,10 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
# Init compute/precision
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = device.type
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_dtype = resolve_autocast_dtype(device_type, dtype)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process

View File

@ -17,7 +17,7 @@ import wandb
import torch
import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, resolve_autocast_dtype
from nanochat.checkpoint_manager import load_model
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.engine import Engine
@ -61,9 +61,10 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = device.type
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_dtype = resolve_autocast_dtype(device_type, dtype)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process

View File

@ -16,7 +16,7 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from nanochat.common import compute_init
from nanochat.common import compute_init, resolve_autocast_dtype
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@ -32,7 +32,9 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
args = parser.parse_args()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = device.type
autocast_dtype = resolve_autocast_dtype(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)
class ChatMessage(BaseModel):
role: str

View File

@ -16,7 +16,7 @@ import time
import wandb
import torch
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, resolve_autocast_dtype
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@ -51,9 +51,11 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = device.type
is_cuda = device_type == "cuda"
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_dtype = resolve_autocast_dtype(device_type, dtype)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@ -129,8 +131,8 @@ def mid_data_generator(split):
scratch[i] = token_buffer.popleft()
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=is_cuda)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=is_cuda)
if split == "train":
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
yield inputs, targets
@ -156,6 +158,7 @@ min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
mfu = None
step = 0
while True:
flops_so_far = num_flops_per_token * total_batch_size * step
@ -214,7 +217,8 @@ while True:
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
torch.cuda.synchronize()
if is_cuda:
torch.cuda.synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
@ -222,7 +226,7 @@ while True:
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
x, y = next(train_loader) # prefetch the next batch while the accelerator is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically
# step the optimizers
lrm = get_lr_multiplier(progress)
@ -235,7 +239,8 @@ while True:
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
torch.cuda.synchronize()
if is_cuda:
torch.cuda.synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
@ -249,13 +254,14 @@ while True:
pct_done = 100 * progress
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
promised_flops_per_sec = 989e12 * ddp_world_size if is_cuda else None # bfloat16 H100 baseline
mfu = 100 * flops_per_sec / promised_flops_per_sec if promised_flops_per_sec else None
mfu_display = f"{mfu:.2f}" if mfu is not None else "N/A"
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu_display} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0:
wandb_run.log({
log_payload = {
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
@ -263,11 +269,17 @@ while True:
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
})
}
if mfu is not None:
log_payload["train/mfu"] = mfu
wandb_run.log(log_payload)
# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
peak_memory_mib = (torch.cuda.max_memory_allocated() / 1024 / 1024) if is_cuda else None
if peak_memory_mib is not None:
print0(f"Peak memory usage: {peak_memory_mib:.2f}MiB")
else:
print0("Peak memory usage: N/A (CPU run)")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")