diff --git a/nanochat/common.py b/nanochat/common.py index 9bcd5dd..dcefd04 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -96,6 +96,12 @@ def download_file_with_lock(url, filename, postprocess_fn=None): def print0(s="",**kwargs): ddp_rank = int(os.environ.get('RANK', 0)) + if 'RANK' not in os.environ: + try: + import torch_xla.core.xla_model as xm + ddp_rank = int(xm.get_ordinal()) + except Exception: + ddp_rank = 0 if ddp_rank == 0: print(s, **kwargs) @@ -136,8 +142,20 @@ def get_dist_info(): ddp_local_rank = int(os.environ['LOCAL_RANK']) ddp_world_size = int(os.environ['WORLD_SIZE']) return True, ddp_rank, ddp_local_rank, ddp_world_size - else: - return False, 0, 0, 1 + + # TPU/XLA uses torch_xla multiprocessing rather than torchrun. + # If we're in an XLA runtime, expose rank/world size so dataloaders shard correctly. + try: + import torch_xla.core.xla_model as xm + ddp_world_size = int(xm.xrt_world_size()) + if ddp_world_size > 1: + ddp_rank = int(xm.get_ordinal()) + ddp_local_rank = int(xm.get_local_ordinal()) + return True, ddp_rank, ddp_local_rank, ddp_world_size + except Exception: + pass + + return False, 0, 0, 1 def autodetect_device_type(): # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU @@ -150,14 +168,19 @@ def autodetect_device_type(): print0(f"Autodetected device type: {device_type}") return device_type -def compute_init(device_type="cuda"): # cuda|cpu|mps +def compute_init(device_type="cuda"): # cuda|cpu|mps|xla """Basic initialization that we keep doing over and over, so make common.""" - assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm" + assert device_type in ["cuda", "mps", "cpu", "xla"], "Invalid device type atm" if device_type == "cuda": assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'" if device_type == "mps": assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but 
device_type is 'mps'" + if device_type == "xla": + try: + import torch_xla.core.xla_model as xm + except Exception as e: + raise AssertionError("device_type is 'xla' but torch_xla is not available") from e # Reproducibility # Note that we set the global seeds here, but most of the code uses explicit rng objects. @@ -179,6 +202,9 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps torch.cuda.set_device(device) # make "cuda" default to this device dist.init_process_group(backend="nccl", device_id=device) dist.barrier() + elif device_type == "xla": + import torch_xla.core.xla_model as xm + device = xm.xla_device() else: device = torch.device(device_type) # mps|cpu diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 125625f..4ab3e6c 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -109,14 +109,21 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit( # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)] # This gives us contiguous views and a single HtoD transfer - use_cuda = device == "cuda" + device_str = str(device) + use_cuda = device == "cuda" or device_str == "cuda" + use_xla = device_str.startswith("xla") row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU) - gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience cpu_targets = cpu_buffer[B * T:].view(B, T) - inputs = gpu_buffer[:B * T].view(B, T) - targets = gpu_buffer[B * T:].view(B, T) + if use_xla: + import torch_xla.core.xla_model as xm + inputs = None + targets = None + else: + gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer + inputs = gpu_buffer[:B * T].view(B, T) + targets = gpu_buffer[B * T:].view(B, T) while True: for 
row_idx in range(B): @@ -156,8 +163,14 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit( state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch} # Single HtoD copy into persistent GPU buffer and yield - gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda) - yield inputs, targets, state_dict + if use_xla: + gpu_buffer = xm.send_cpu_data_to_device(cpu_buffer, device) + inputs = gpu_buffer[:B * T].view(B, T) + targets = gpu_buffer[B * T:].view(B, T) + yield inputs, targets, state_dict + else: + gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda) + yield inputs, targets, state_dict def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs): """Helper that omits state_dict from yields.""" diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..b4e5794 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -253,7 +253,10 @@ class GPT(nn.Module): # calculate the rotation frequencies at each (time, channel) pair freqs = torch.outer(t, inv_freq) cos, sin = freqs.cos(), freqs.sin() - cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16 + if device.type == "mps": + cos, sin = cos.float(), sin.float() # avoid bf16/fp32 mixing issues on MPSGraph + else: + cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16 cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting return cos, sin @@ -391,7 +394,10 @@ class GPT(nn.Module): # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" - assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" + if self.cos.device.type == "mps": + assert self.cos.dtype == torch.float32, "Rotary embeddings must be float32 on MPS" + 
else: + assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache T0 = 0 if kv_cache is None else kv_cache.get_pos() cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index 5a556e6..3092f7d 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -4,6 +4,7 @@ A number of functions that help with evaluating a base model. import math import torch import torch.distributed as dist +from nanochat.common import get_dist_info @torch.no_grad() def evaluate_bpb(model, batches, steps, token_bytes): @@ -52,10 +53,16 @@ def evaluate_bpb(model, batches, steps, token_bytes): total_nats += (loss2d * (num_bytes2d > 0)).sum() total_bytes += num_bytes2d.sum() # sum reduce across all ranks - world_size = dist.get_world_size() if dist.is_initialized() else 1 + _, _, _, world_size = get_dist_info() if world_size > 1: - dist.all_reduce(total_nats, op=dist.ReduceOp.SUM) - dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM) + if dist.is_initialized(): + dist.all_reduce(total_nats, op=dist.ReduceOp.SUM) + dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM) + else: + # TPU/XLA path (multi-process without torch.distributed process group) + import torch_xla.core.xla_model as xm + total_nats = xm.all_reduce(xm.REDUCE_SUM, total_nats) + total_bytes = xm.all_reduce(xm.REDUCE_SUM, total_bytes) # move both to cpu, calculate bpb and return total_nats = total_nats.item() total_bytes = total_bytes.item() diff --git a/nanochat/optim.py b/nanochat/optim.py index 190a1ed..c6f6d55 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -48,6 +48,34 @@ def adamw_step_fused( step_size = lr_t / bias1 p.add_(exp_avg / denom, alpha=-step_size) +def adamw_step_eager( + p: Tensor, + grad: Tensor, + exp_avg: Tensor, + exp_avg_sq: Tensor, + step_t: Tensor, + lr_t: Tensor, + 
beta1_t: Tensor, + beta2_t: Tensor, + eps_t: Tensor, + wd_t: Tensor, +) -> None: + # Same math as adamw_step_fused but without torch.compile. + step_t = step_t.to(device=p.device) + lr_t = lr_t.to(device=p.device) + beta1_t = beta1_t.to(device=p.device) + beta2_t = beta2_t.to(device=p.device) + eps_t = eps_t.to(device=p.device) + wd_t = wd_t.to(device=p.device) + p.mul_(1 - lr_t * wd_t) + exp_avg.lerp_(grad, 1 - beta1_t) + exp_avg_sq.lerp_(grad.square(), 1 - beta2_t) + bias1 = 1 - beta1_t ** step_t + bias2 = 1 - beta2_t ** step_t + denom = (exp_avg_sq / bias2).sqrt() + eps_t + step_size = lr_t / bias1 + p.add_(exp_avg / denom, alpha=-step_size) + # ----------------------------------------------------------------------------- """ Muon optimizer adapted and simplified from modded-nanogpt. @@ -108,7 +136,12 @@ def muon_step_fused( g = stacked_grads.lerp_(momentum_buffer, momentum) # Polar express - X = g.bfloat16() + # MPS has limited/fragile bfloat16 support and will assert if BF16 and FP32 mix + # inside MPSGraph. Keep Muon math in FP32 on MPS for stability. + if stacked_grads.device.type == "mps": + X = g.float() + else: + X = g.bfloat16() X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6) if g.size(-2) > g.size(-1): # Tall matrix for a, b, c in polar_express_coeffs[:ns_steps]: @@ -120,7 +153,7 @@ def muon_step_fused( A = X @ X.mT B = b * A + c * (A @ A) X = a * X + B @ X - g = X + g = X.to(dtype=stacked_params.dtype) # Variance reduction beta2 = beta2_t.to(g.dtype) @@ -141,6 +174,58 @@ def muon_step_fused( mask = (g * stacked_params) >= 0 stacked_params.sub_(lr * g + lr * wd * stacked_params * mask) +def muon_step_eager( + stacked_grads: Tensor, + stacked_params: Tensor, + momentum_buffer: Tensor, + second_momentum_buffer: Tensor, + momentum_t: Tensor, + lr_t: Tensor, + wd_t: Tensor, + beta2_t: Tensor, + ns_steps: int, + red_dim: int, +) -> None: + # Same math as muon_step_fused but without torch.compile. 
+ momentum_t = momentum_t.to(device=stacked_grads.device) + lr_t = lr_t.to(device=stacked_grads.device) + wd_t = wd_t.to(device=stacked_grads.device) + beta2_t = beta2_t.to(device=stacked_grads.device) + momentum = momentum_t.to(stacked_grads.dtype) + momentum_buffer.lerp_(stacked_grads, 1 - momentum) + g = stacked_grads.lerp_(momentum_buffer, momentum) + + X = g.bfloat16() + X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6) + if g.size(-2) > g.size(-1): + for a, b, c in polar_express_coeffs[:ns_steps]: + A = X.mT @ X + B = b * A + c * (A @ A) + X = a * X + X @ B + else: + for a, b, c in polar_express_coeffs[:ns_steps]: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + g = X.to(dtype=stacked_params.dtype) + + beta2 = beta2_t.to(g.dtype) + v_mean = g.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = g.size(red_dim) + v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size + v_norm = v_norm_sq.sqrt() + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt() + final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10)) + g = g * final_scale.to(g.dtype) + + lr = lr_t.to(g.dtype) + wd = wd_t.to(g.dtype) + mask = (g * stacked_params) >= 0 + stacked_params.sub_(lr * g + lr * wd * stacked_params * mask) + # ----------------------------------------------------------------------------- # Single GPU version of the MuonAdamW optimizer. # Used mostly for reference, debugging and testing. 
@@ -216,11 +301,18 @@ class MuonAdamW(torch.optim.Optimizer): self._adamw_wd_t.fill_(group['weight_decay']) # Fused update: weight_decay -> momentum -> bias_correction -> param_update - adamw_step_fused( - p, grad, exp_avg, exp_avg_sq, - self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t, - self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t, - ) + if p.device.type == "mps": + adamw_step_eager( + p, grad, exp_avg, exp_avg_sq, + self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t, + self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t, + ) + else: + adamw_step_fused( + p, grad, exp_avg, exp_avg_sq, + self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t, + self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t, + ) def _step_muon(self, group: dict) -> None: """ @@ -260,18 +352,32 @@ class MuonAdamW(torch.optim.Optimizer): self._muon_wd_t.fill_(group["weight_decay"]) # Single fused kernel: momentum -> polar_express -> variance_reduction -> update - muon_step_fused( - stacked_grads, - stacked_params, - momentum_buffer, - second_momentum_buffer, - self._muon_momentum_t, - self._muon_lr_t, - self._muon_wd_t, - self._muon_beta2_t, - group["ns_steps"], - red_dim, - ) + if device.type == "mps": + muon_step_eager( + stacked_grads, + stacked_params, + momentum_buffer, + second_momentum_buffer, + self._muon_momentum_t, + self._muon_lr_t, + self._muon_wd_t, + self._muon_beta2_t, + group["ns_steps"], + red_dim, + ) + else: + muon_step_fused( + stacked_grads, + stacked_params, + momentum_buffer, + second_momentum_buffer, + self._muon_momentum_t, + self._muon_lr_t, + self._muon_wd_t, + self._muon_beta2_t, + group["ns_steps"], + red_dim, + ) # Copy back to original params torch._foreach_copy_(params, list(stacked_params.unbind(0))) diff --git a/scripts/base_train.py b/scripts/base_train.py index fa05b60..5bac21e 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -17,12 +17,15 @@ os.environ["PYTORCH_ALLOC_CONF"] = 
"expandable_segments:True" import argparse import time from contextlib import nullcontext, contextmanager +import runpy -import wandb import torch +try: + import wandb +except ModuleNotFoundError: + wandb = None from nanochat.gpt import GPT, GPTConfig -from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint @@ -38,7 +41,7 @@ parser = argparse.ArgumentParser(description="Pretrain base model") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime -parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps|xla (empty = autodetect)") # FP8 training parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") @@ -75,16 +78,66 @@ parser.add_argument("--sample-every", type=int, default=2000, help="sample from parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") # Output parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") +parser.add_argument("--smoke-test", action="store_true", help="run a tiny fast configuration for debugging (overrides several training args)") args = parser.parse_args() user_config = vars(args).copy() # for logging # 
----------------------------------------------------------------------------- # Compute init device_type = autodetect_device_type() if args.device_type == "" else args.device_type + +# TPU/XLA launcher: XLA uses multi-process execution via xmp.spawn rather than torchrun. +# When device_type is xla and we're not already in a spawned worker, spawn and re-run this module. +if device_type == "xla" and os.environ.get("NANOCHAT_XLA_SPAWNED", "0") != "1": + try: + import torch_xla.distributed.xla_multiprocessing as xmp + import torch_xla.core.xla_model as xm + except Exception as e: + raise RuntimeError("--device-type xla requested but torch_xla is not available") from e + + def _xla_worker(_index, argv): + os.environ["NANOCHAT_XLA_SPAWNED"] = "1" + # Ensure child sees the same argv (xmp.spawn does not preserve it by default in all envs) + import sys + sys.argv = argv + runpy.run_module("scripts.base_train", run_name="__main__") + + xmp.spawn(_xla_worker, args=(os.sys.argv,), nprocs=xm.xrt_world_size()) + raise SystemExit(0) + ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) + +if args.smoke_test: + args.run = "dummy" + args.depth = 2 + args.aspect_ratio = 32 + args.head_dim = 64 + args.max_seq_len = 128 + args.device_batch_size = 1 + args.total_batch_size = args.device_batch_size * args.max_seq_len * ddp_world_size + args.num_iterations = 5 + args.target_flops = -1.0 + args.target_param_data_ratio = -1 + args.eval_every = -1 + args.core_metric_every = -1 + args.sample_every = -1 + args.save_every = -1 + args.resume_from_step = -1 + args.fp8 = False + user_config = vars(args).copy() # refresh for logging + master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() -synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None +if device_type == "cuda": + autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) +elif device_type == "xla": + from torch_xla.amp import autocast as xla_autocast + autocast_ctx = xla_autocast(dtype=torch.bfloat16) +else: + autocast_ctx = nullcontext() +synchronize = torch.cuda.synchronize if device_type == "cuda" else (lambda: None) +if device_type == "xla": + import torch_xla.core.xla_model as xm + synchronize = xm.mark_step get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 if device_type == "cuda": gpu_device_name = torch.cuda.get_device_name(0) @@ -95,7 +148,7 @@ else: # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config) +wandb_run = DummyWandb() if (use_dummy_wandb or wandb is None) else wandb.init(project="nanochat", name=args.run, config=user_config) # Flash Attention status if HAS_FA3: @@ -110,10 +163,16 @@ else: print0("!" 
* 80) # Tokenizer will be useful for evaluation, also we need the vocab size -tokenizer = get_tokenizer() -token_bytes = get_token_bytes(device=device) -vocab_size = tokenizer.get_vocab_size() -print0(f"Vocab size: {vocab_size:,}") +if args.smoke_test: + tokenizer = None + vocab_size = 8192 + token_bytes = torch.ones(vocab_size, dtype=torch.int64, device=device) + print0(f"Vocab size: {vocab_size:,}") +else: + tokenizer = get_tokenizer() + token_bytes = get_token_bytes(device=device) + vocab_size = tokenizer.get_vocab_size() + print0(f"Vocab size: {vocab_size:,}") # Model kwargs are derived from the desired depth of the model # We nudge model_dim up to the nearest multiple of head_dim to ensure clean division @@ -291,7 +350,12 @@ def disable_fp8(model): # Compile the model orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +if device_type == "xla" and os.environ.get("NANOCHAT_XLA_TORCH_COMPILE", "0") != "1": + model = model +elif device_type == "mps" and os.environ.get("NANOCHAT_MPS_TORCH_COMPILE", "0") != "1": + model = model +else: + model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe # ----------------------------------------------------------------------------- # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) @@ -312,9 +376,23 @@ if resuming: # ----------------------------------------------------------------------------- # Initialize the DataLoaders for train/val dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] -train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, 
resume_state_dict=dataloader_resume_state_dict) -build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device) -x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data +if args.smoke_test: + # Synthetic data path for local debugging: avoids requiring pyarrow + dataset parquet files. + vocab_size_for_smoke = vocab_size + def _synthetic_loader(): + while True: + x = torch.randint(0, vocab_size_for_smoke, (args.device_batch_size, args.max_seq_len), device=device, dtype=torch.long) + y = torch.randint(0, vocab_size_for_smoke, (args.device_batch_size, args.max_seq_len), device=device, dtype=torch.long) + state_dict = {"pq_idx": 0, "rg_idx": 0, "epoch": 1} + yield x, y, state_dict + train_loader = _synthetic_loader() + build_val_loader = lambda: iter(()) + x, y, dataloader_state_dict = next(train_loader) +else: + from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit + train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) + build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device) + x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data # ----------------------------------------------------------------------------- # Set up hyperparameter schedulers @@ -421,7 +499,13 @@ while True: model.train() # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step - if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0): + if (not args.smoke_test) and (last_step or (step 
> 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0)): + if ddp: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + elif device_type == "xla": + import torch_xla.core.xla_model as xm + xm.rendezvous("nanochat_base_train_checkpoint") save_checkpoint( checkpoint_dir, step, @@ -443,6 +527,12 @@ while True: }, rank=ddp_rank, ) + if ddp: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + elif device_type == "xla": + import torch_xla.core.xla_model as xm + xm.rendezvous("nanochat_base_train_checkpoint_done") # termination conditions (TODO: possibly also add loss explosions etc.) if last_step: @@ -533,30 +623,31 @@ if val_bpb is not None: print0(f"Minimum validation bpb: {min_val_bpb:.6f}") # Log to report -from nanochat.report import get_report -get_report().log(section="Base model training", data=[ - user_config, # CLI args - { # stats about the training setup - "Number of parameters": num_params, - "Number of FLOPs per token": f"{num_flops_per_token:e}", - "Calculated number of iterations": num_iterations, - "Number of training tokens": total_tokens, - "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params, - "DDP world size": ddp_world_size, - "warmup_ratio": args.warmup_ratio, - "warmdown_ratio": args.warmdown_ratio, - "final_lr_frac": args.final_lr_frac, - }, - { # stats about training outcomes - "Minimum validation bpb": min_val_bpb if val_bpb is not None else None, - "Final validation bpb": val_bpb, - "CORE metric estimate": results.get("core_metric", None), - "MFU %": f"{mfu:.2f}%", - "Total training flops": f"{flops_so_far:e}", - "Total training time": f"{total_training_time/60:.2f}m", - "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB", - } -]) +if not args.smoke_test: + from nanochat.report import get_report + get_report().log(section="Base model 
training", data=[ + user_config, # CLI args + { # stats about the training setup + "Number of parameters": num_params, + "Number of FLOPs per token": f"{num_flops_per_token:e}", + "Calculated number of iterations": num_iterations, + "Number of training tokens": total_tokens, + "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params, + "DDP world size": ddp_world_size, + "warmup_ratio": args.warmup_ratio, + "warmdown_ratio": args.warmdown_ratio, + "final_lr_frac": args.final_lr_frac, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb if val_bpb is not None else None, + "Final validation bpb": val_bpb, + "CORE metric estimate": results.get("core_metric", None), + "MFU %": f"{mfu:.2f}%", + "Total training flops": f"{flops_so_far:e}", + "Total training time": f"{total_training_time/60:.2f}m", + "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB", + } + ]) # cleanup wandb_run.finish() # wandb run finish diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py index 7de7e10..e88b1ba 100644 --- a/scripts/chat_cli.py +++ b/scripts/chat_cli.py @@ -18,7 +18,7 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back') parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation') parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter') -parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') +parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps', 'xla'], help='Device type for evaluation: cuda|cpu|mps|xla. 
empty => autodetect') parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) args = parser.parse_args() @@ -27,7 +27,13 @@ args = parser.parse_args() device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() +if device_type == "cuda": + autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) +elif device_type == "xla": + from torch_xla.amp import autocast as xla_autocast + autocast_ctx = xla_autocast(dtype=ptdtype) +else: + autocast_ctx = nullcontext() model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) # Special tokens for the chat state machine diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index bc15239..3325b78 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -71,8 +71,13 @@ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_ if ddp: num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device) total_tensor = torch.tensor([total], dtype=torch.long, device=device) - dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM) - dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM) + if dist.is_initialized(): + dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM) + dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM) + else: + import torch_xla.core.xla_model as xm + num_passed_tensor = xm.all_reduce(xm.REDUCE_SUM, num_passed_tensor) + total_tensor = xm.all_reduce(xm.REDUCE_SUM, total_tensor) num_passed = num_passed_tensor.item() total = total_tensor.item() @@ -145,8 +150,13 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems if ddp: 
num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device) total_tensor = torch.tensor([total], dtype=torch.long, device=device) - dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM) - dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM) + if dist.is_initialized(): + dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM) + dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM) + else: + import torch_xla.core.xla_model as xm + num_passed_tensor = xm.all_reduce(xm.REDUCE_SUM, num_passed_tensor) + total_tensor = xm.all_reduce(xm.REDUCE_SUM, total_tensor) num_passed = num_passed_tensor.item() total = total_tensor.item() @@ -194,13 +204,19 @@ if __name__ == "__main__": parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate') - parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') + parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps', 'xla'], help='Device type for evaluation: cuda|cpu|mps|xla. 
empty => autodetect') args = parser.parse_args() device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 - autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() + if device_type == "cuda": + autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) + elif device_type == "xla": + from torch_xla.amp import autocast as xla_autocast + autocast_ctx = xla_autocast(dtype=ptdtype) + else: + autocast_ctx = nullcontext() model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) engine = Engine(model, tokenizer) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 4c81f06..d0e6b38 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -13,7 +13,7 @@ import argparse import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import time -import wandb +import runpy import torch from contextlib import nullcontext from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type @@ -30,13 +30,18 @@ from tasks.smoltalk import SmolTalk from tasks.customjson import CustomJSON from tasks.spellingbee import SimpleSpelling, SpellingBee +try: + import wandb +except ModuleNotFoundError: + wandb = None + # ----------------------------------------------------------------------------- # CLI arguments parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime -parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps|xla 
(empty = autodetect)")
 parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
 # Model loading
 parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
@@ -58,22 +63,63 @@ parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bp
 parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
 # Output
 parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
+parser.add_argument("--smoke-test", action="store_true", help="run a tiny fast configuration for debugging (overrides several training args)")
 args = parser.parse_args()
 user_config = vars(args).copy()
 # -----------------------------------------------------------------------------
 # Compute init
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
+
+# TPU/XLA launcher: XLA uses multi-process execution via xmp.spawn rather than torchrun
+# When device_type is xla and we're not already in a spawned worker, spawn and re-run this module
+if device_type == "xla" and os.environ.get("NANOCHAT_XLA_SPAWNED", "0") != "1":
+    try:
+        import torch_xla.distributed.xla_multiprocessing as xmp
+        import torch_xla.core.xla_model as xm
+    except Exception as e:
+        raise RuntimeError("--device-type xla requested but torch_xla is not available") from e
+
+    def _xla_worker(_index, argv):
+        os.environ["NANOCHAT_XLA_SPAWNED"] = "1"
+        import sys
+        sys.argv = argv
+        runpy.run_module("scripts.chat_sft", run_name="__main__")
+
+    xmp.spawn(_xla_worker, args=(os.sys.argv,), nprocs=xm.xrt_world_size())
+    raise SystemExit(0)
+
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+
+if args.smoke_test:
+    args.run = "dummy"
+    args.dry_run = True
+    args.max_seq_len = 128
+    args.device_batch_size = 1
+    args.total_batch_size = args.device_batch_size * args.max_seq_len * ddp_world_size
+    args.num_iterations = 5
+    args.eval_every = -1
+    args.eval_tokens = args.total_batch_size
+    user_config = vars(args).copy()  # refresh for logging
+
 master_process = ddp_rank == 0
 ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
-synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
+if device_type == "cuda":
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+elif device_type == "xla":
+    from torch_xla.amp import autocast as xla_autocast
+    autocast_ctx = xla_autocast(dtype=ptdtype)
+else:
+    autocast_ctx = nullcontext()
+synchronize = torch.cuda.synchronize if device_type == "cuda" else (lambda: None)
+if device_type == "xla":
+    import torch_xla.core.xla_model as xm
+    synchronize = xm.mark_step
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

 # wandb logging init
 use_dummy_wandb = args.run == "dummy" or not master_process
-wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
+wandb_run = DummyWandb() if (use_dummy_wandb or wandb is None) else wandb.init(project="nanochat-sft", name=args.run, config=user_config)

 # Load the model and tokenizer
 model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
@@ -81,7 +127,10 @@ pretrain_batch_size = meta.get("device_batch_size", None)
 if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
     print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
 orig_model = model
-model = torch.compile(model, dynamic=False)
+if device_type == "xla" and os.environ.get("NANOCHAT_XLA_TORCH_COMPILE", "0") != "1":
+    model = model
+else:
+    model = torch.compile(model, dynamic=False)
 depth = model.config.n_layer
 num_flops_per_token = model.estimate_flops()
 tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
@@ -261,8 +310,13 @@ while True:
     # Synchronize last_step across all ranks to avoid hangs in the distributed setting
     if ddp:
         last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
-        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
-        last_step = bool(last_step_tensor.item())
+        if dist.is_initialized():
+            dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
+            last_step = bool(last_step_tensor.item())
+        else:
+            import torch_xla.core.xla_model as xm
+            last_step_tensor = xm.all_reduce(xm.REDUCE_MAX, last_step_tensor)
+            last_step = bool(last_step_tensor.item())

     # once in a while: evaluate the val bpb (all ranks participate)
     if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
@@ -286,6 +340,12 @@ while True:
     if master_process and last_step and not args.dry_run:
         output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
         checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
+        if ddp:
+            if dist.is_initialized():
+                dist.barrier()
+            elif device_type == "xla":
+                import torch_xla.core.xla_model as xm
+                xm.rendezvous("nanochat_chat_sft_checkpoint")
         save_checkpoint(
             checkpoint_dir,
             step,
@@ -306,6 +366,12 @@ while True:
             "user_config": user_config, # inputs to the training script
         }
     )
+    if ddp:
+        if dist.is_initialized():
+            dist.barrier()
+        elif device_type == "xla":
+            import torch_xla.core.xla_model as xm
+            xm.rendezvous("nanochat_chat_sft_checkpoint_done")

     if last_step:
         break
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 66d7806..a7182ec 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -70,7 +70,7 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
 parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
 parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
 parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
-parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
+parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps', 'xla'], help='Device type for evaluation: cuda|cpu|mps|xla. empty => autodetect')
 parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
 args = parser.parse_args()
@@ -103,7 +103,7 @@ class WorkerPool:
         if device_type == "cuda":
             num_gpus = torch.cuda.device_count()
         else:
-            num_gpus = 1 # e.g. cpu|mps
+            num_gpus = 1 # e.g. cpu|mps|xla
         self.num_gpus = num_gpus
         self.workers: List[Worker] = []
         self.available_workers: asyncio.Queue = asyncio.Queue()
@@ -119,13 +119,23 @@
             if device_type == "cuda":
                 device = torch.device(f"cuda:{gpu_id}")
                 print(f"Loading model on GPU {gpu_id}...")
+            elif device_type == "xla":
+                import torch_xla.core.xla_model as xm
+                device = xm.xla_device()
+                print(f"Loading model on {device_type}...")
             else:
                 device = torch.device(device_type) # e.g. cpu|mps
                 print(f"Loading model on {device_type}...")
             model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
             engine = Engine(model, tokenizer)
-            autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+            if device_type == "cuda":
+                autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+            elif device_type == "xla":
+                from torch_xla.amp import autocast as xla_autocast
+                autocast_ctx = xla_autocast(dtype=ptdtype)
+            else:
+                autocast_ctx = nullcontext()
             worker = Worker(
                 gpu_id=gpu_id,