mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
trying to add basic cpu support, will try mps too
This commit is contained in:
parent
4346536ab2
commit
722da4f543
|
|
@ -89,15 +89,16 @@ def get_dist_info():
|
|||
else:
|
||||
return False, 0, 0, 1
|
||||
|
||||
def compute_init():
|
||||
def compute_init(device_type="cuda"): # cuda|cpu
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
# CUDA is currently required
|
||||
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
|
||||
# assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
|
||||
|
||||
# Reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
if device_type == "cuda":
|
||||
torch.cuda.manual_seed(42)
|
||||
# skipping full reproducibility for now, possibly investigate slowdown later
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
|
|
@ -106,15 +107,15 @@ def compute_init():
|
|||
# Precision
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
if ddp:
|
||||
if ddp and device_type == "cuda":
|
||||
device = torch.device("cuda", ddp_local_rank)
|
||||
torch.cuda.set_device(device) # make "cuda" default to this device
|
||||
dist.init_process_group(backend="nccl", device_id=device)
|
||||
dist.barrier()
|
||||
else:
|
||||
device = torch.device("cuda")
|
||||
device = torch.device(device_type) # cuda|cpu
|
||||
|
||||
if ddp_rank == 0:
|
||||
logger.info(f"Distributed world size: {ddp_world_size}")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
|
|||
from nanochat.dataset import parquets_iter_batched
|
||||
from nanochat.tokenizer import get_tokenizer
|
||||
|
||||
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
|
||||
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
|
||||
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
|
@ -44,6 +44,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
|
|||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
yield inputs, targets
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ python base_train.py
|
|||
or distributed as:
|
||||
|
||||
torchrun --nproc_per_node=8 base_train.py
|
||||
|
||||
If you just want to see it run on CPU (you won't get far but it should run), try something like:
|
||||
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --device_type=cpu --eval_tokens=512 --total_batch_size=512 --num_iterations=1000
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -27,6 +30,8 @@ print_banner()
|
|||
# -----------------------------------------------------------------------------
|
||||
# User settings
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# Runtime
|
||||
device_type = "cuda" # cuda|cpu
|
||||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
|
|
@ -57,9 +62,11 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
|
|||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
@ -96,7 +103,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
|
|||
with torch.device("meta"):
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device="cuda")
|
||||
model.to_empty(device=device)
|
||||
model.init_weights()
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict
|
||||
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
|
||||
|
|
@ -133,8 +140,8 @@ adamw_optimizer, muon_optimizer = optimizers
|
|||
# Initialize the DataLoaders for train/val
|
||||
base_dir = get_base_dir()
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
|
||||
x, y = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -252,7 +259,7 @@ for step in range(num_iterations + 1):
|
|||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
|
|
@ -275,7 +282,7 @@ for step in range(num_iterations + 1):
|
|||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
|
@ -304,7 +311,7 @@ for step in range(num_iterations + 1):
|
|||
})
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
|
|
@ -330,7 +337,7 @@ get_report().log(section="Base model training", data=[
|
|||
"MFU %": f"{mfu:.2f}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time/60:.2f}m",
|
||||
"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
|
||||
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
|
||||
}
|
||||
])
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user