trying to add basic CPU support, will try MPS too

Andrej Karpathy 2025-10-16 16:14:38 +00:00
parent 4346536ab2
commit 722da4f543
3 changed files with 26 additions and 18 deletions
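The diff below wires a cuda|cpu switch through compute_init, the data loader, and base_train; the MPS support mentioned in the commit title is not part of this diff yet. One way it could later slot into the same convention is a small resolver along these lines (a hypothetical helper, not part of this commit; assumes a PyTorch build that exposes torch.backends.mps):

import torch

# Hypothetical helper: resolve a requested device_type, falling back to CPU
# when the requested backend is not actually available on this machine.
def resolve_device_type(requested: str = "cuda") -> str:
    if requested == "cuda" and torch.cuda.is_available():
        return "cuda"
    if requested == "mps" and torch.backends.mps.is_available():
        return "mps"
    return "cpu"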


@@ -89,14 +89,15 @@ def get_dist_info():
else:
return False, 0, 0, 1
def compute_init():
def compute_init(device_type="cuda"): # cuda|cpu
"""Basic initialization that we keep doing over and over, so make common."""
# CUDA is currently required
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
# assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
# Reproducibility
torch.manual_seed(42)
if device_type == "cuda":
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
@@ -106,15 +107,15 @@ def compute_init():
# Precision
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp:
if ddp and device_type == "cuda":
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
else:
device = torch.device("cuda")
device = torch.device(device_type) # cuda|cpu
if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}")


@@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
assert split in ["train", "val"], "split must be 'train' or 'val'"
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@@ -44,6 +44,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
yield inputs, targets
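With device threaded through as a parameter, the same loader can drive a CPU smoke test; an illustrative call under the defaults shown above (B/T values are arbitrary, and non_blocking is simply a no-op for CPU-to-CPU copies):

# Illustrative usage, mirroring the base_train change further down in this commit.
loader = tokenizing_distributed_data_loader(B=1, T=512, split="train", device="cpu")
inputs, targets = next(loader)  # (B, T) int32 inputs and int64 targets on the requested device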


@@ -6,6 +6,9 @@ python base_train.py
or distributed as:
torchrun --nproc_per_node=8 base_train.py
If you just want to see it run on CPU (you won't get far but it should run), try something like:
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --device_type=cpu --eval_tokens=512 --total_batch_size=512 --num_iterations=1000
"""
import os
@@ -27,6 +30,8 @@ print_banner()
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Runtime
device_type = "cuda" # cuda|cpu
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
@@ -57,9 +62,11 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
# -----------------------------------------------------------------------------
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
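The synchronize/get_max_memory shims are what keep the timing and reporting code further down device-agnostic: CPU eager ops complete synchronously and there is no allocator peak-memory counter, so no-op fallbacks suffice. If MPS lands later, the same pattern could be extended roughly as follows (a hypothetical sketch, not part of this commit; it assumes torch.mps.synchronize and torch.mps.current_allocated_memory are available in the installed PyTorch, and the latter reports current rather than peak allocation):

# Hypothetical extension of the shims to MPS (not in this commit).
if device_type == "cuda":
    synchronize = torch.cuda.synchronize
    get_max_memory = torch.cuda.max_memory_allocated
elif device_type == "mps":
    synchronize = torch.mps.synchronize                  # wait for queued MPS work
    get_max_memory = torch.mps.current_allocated_memory  # current, not peak, allocation
else:
    synchronize = lambda: None
    get_max_memory = lambda: 0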
@@ -96,7 +103,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device="cuda")
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
@@ -133,8 +140,8 @@ adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
@@ -252,7 +259,7 @@ for step in range(num_iterations + 1):
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
torch.cuda.synchronize()
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
@@ -275,7 +282,7 @@ for step in range(num_iterations + 1):
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
torch.cuda.synchronize()
synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
@@ -304,7 +311,7 @@ for step in range(num_iterations + 1):
})
# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
@@ -330,7 +337,7 @@ get_report().log(section="Base model training", data=[
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
])