diff --git a/nanochat/common.py b/nanochat/common.py
index 8b10df9..22232d1 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -89,15 +89,16 @@ def get_dist_info():
     else:
         return False, 0, 0, 1
 
-def compute_init():
+def compute_init(device_type="cuda"): # cuda|cpu
     """Basic initialization that we keep doing over and over, so make common."""
 
     # CUDA is currently required
-    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
+    # assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
 
     # Reproducibility
     torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
+    if device_type == "cuda":
+        torch.cuda.manual_seed(42)
     # skipping full reproducibility for now, possibly investigate slowdown later
     # torch.use_deterministic_algorithms(True)
     # torch.backends.cudnn.deterministic = True
@@ -106,15 +107,15 @@ def compute_init():
     # Precision
     torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
 
-    # Distributed setup: Distributed Data Parallel (DDP), optional
+    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp:
+    if ddp and device_type == "cuda":
         device = torch.device("cuda", ddp_local_rank)
         torch.cuda.set_device(device) # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
     else:
-        device = torch.device("cuda")
+        device = torch.device(device_type) # cuda|cpu
 
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")
diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py
index c1636b1..12e7d8e 100644
--- a/nanochat/dataloader.py
+++ b/nanochat/dataloader.py
@@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
 from nanochat.dataset import parquets_iter_batched
 from nanochat.tokenizer import get_tokenizer
 
-def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
+def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
     """Stream pretraining text from parquet files, tokenize, yield training batches."""
     assert split in ["train", "val"], "split must be 'train' or 'val'"
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@@ -44,6 +44,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
         inputs_cpu = scratch[:-1].to(dtype=torch.int32)
         targets_cpu = scratch[1:]
         # Reshape to 2D and move to GPU async
-        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
-        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
+        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
+        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
         yield inputs, targets
diff --git a/scripts/base_train.py b/scripts/base_train.py
index b691ed4..166e11e 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -6,6 +6,9 @@ python base_train.py
 or distributed as:
 
 torchrun --nproc_per_node=8 base_train.py
+
+If you just want to see it run on CPU (you won't get far but it should run), try something like:
+python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --device_type=cpu --eval_tokens=512 --total_batch_size=512 --num_iterations=1000
 """
 
 import os
@@ -27,6 +30,8 @@ print_banner()
 # -----------------------------------------------------------------------------
 # User settings
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
+# Runtime
+device_type = "cuda" # cuda|cpu
 # Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
 max_seq_len = 2048 # max context length
@@ -57,9 +62,11 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
 
 # -----------------------------------------------------------------------------
 # Compute init
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
+synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
 
 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process
@@ -96,7 +103,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
 with torch.device("meta"):
     model_config = GPTConfig(**model_config_kwargs)
     model = GPT(model_config)
-model.to_empty(device="cuda")
+model.to_empty(device=device)
 model.init_weights()
 orig_model = model # original, uncompiled model, for saving raw model state_dict
 model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
@@ -133,8 +140,8 @@ adamw_optimizer, muon_optimizer = optimizers
 # Initialize the DataLoaders for train/val
 base_dir = get_base_dir()
 tokens_dir = os.path.join(base_dir, "tokenized_data")
-train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
-build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
+train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
+build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
 x, y = next(train_loader) # kick off load of the very first batch of data
 
 # -----------------------------------------------------------------------------
@@ -252,7 +259,7 @@ for step in range(num_iterations + 1):
     # -------------------------------------------------------------------------
     # single training step
     # evaluate the gradient
-    torch.cuda.synchronize()
+    synchronize()
     t0 = time.time()
     for micro_step in range(grad_accum_steps):
         with autocast_ctx:
@@ -275,7 +282,7 @@ for step in range(num_iterations + 1):
     for opt in optimizers:
         opt.step()
     model.zero_grad(set_to_none=True)
-    torch.cuda.synchronize()
+    synchronize()
     t1 = time.time()
     dt = t1 - t0
     # -------------------------------------------------------------------------
@@ -304,7 +311,7 @@ for step in range(num_iterations + 1):
         })
 
 # print a few more stats
-print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
+print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
 print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
 
@@ -330,7 +337,7 @@ get_report().log(section="Base model training", data=[
         "MFU %": f"{mfu:.2f}%",
         "Total training flops": f"{flops_so_far:e}",
         "Total training time": f"{total_training_time/60:.2f}m",
-        "Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
+        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
     }
 ])