Add TPU/XLA support and local smoke tests

This commit is contained in:
Casey Franco 2026-02-03 18:34:45 -05:00
parent 16b8ac7da3
commit c96dfaa97a
10 changed files with 439 additions and 92 deletions

View File

@ -96,6 +96,12 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
def print0(s="", **kwargs):
    """Print `s` only on the global rank-0 process.

    Rank comes from the RANK env var (set by torchrun). When RANK is absent
    we try the torch_xla ordinal (XLA workers are spawned without RANK set),
    and fall back to rank 0 when torch_xla is unavailable.
    """
    if 'RANK' in os.environ:
        rank = int(os.environ['RANK'])
    else:
        try:
            import torch_xla.core.xla_model as xm
            rank = int(xm.get_ordinal())
        except Exception:
            rank = 0
    if rank == 0:
        print(s, **kwargs)
@ -136,8 +142,20 @@ def get_dist_info():
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
return True, ddp_rank, ddp_local_rank, ddp_world_size
else:
return False, 0, 0, 1
# TPU/XLA uses torch_xla multiprocessing rather than torchrun.
# If we're in an XLA runtime, expose rank/world size so dataloaders shard correctly.
try:
import torch_xla.core.xla_model as xm
ddp_world_size = int(xm.xrt_world_size())
if ddp_world_size > 1:
ddp_rank = int(xm.get_ordinal())
ddp_local_rank = int(xm.get_local_ordinal())
return True, ddp_rank, ddp_local_rank, ddp_world_size
except Exception:
pass
return False, 0, 0, 1
def autodetect_device_type():
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
@ -150,14 +168,19 @@ def autodetect_device_type():
print0(f"Autodetected device type: {device_type}")
return device_type
def compute_init(device_type="cuda"): # cuda|cpu|mps
def compute_init(device_type="cuda"): # cuda|cpu|mps|xla
"""Basic initialization that we keep doing over and over, so make common."""
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
assert device_type in ["cuda", "mps", "cpu", "xla"], "Invalid device type atm"
if device_type == "cuda":
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
if device_type == "mps":
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
if device_type == "xla":
try:
import torch_xla.core.xla_model as xm
except Exception as e:
raise AssertionError("device_type is 'xla' but torch_xla is not available") from e
# Reproducibility
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
@ -179,6 +202,9 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
elif device_type == "xla":
import torch_xla.core.xla_model as xm
device = xm.xla_device()
else:
device = torch.device(device_type) # mps|cpu

View File

@ -109,14 +109,21 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
# This gives us contiguous views and a single HtoD transfer
use_cuda = device == "cuda"
device_str = str(device)
use_cuda = device == "cuda" or device_str == "cuda"
use_xla = device_str.startswith("xla")
row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
cpu_targets = cpu_buffer[B * T:].view(B, T)
inputs = gpu_buffer[:B * T].view(B, T)
targets = gpu_buffer[B * T:].view(B, T)
if use_xla:
import torch_xla.core.xla_model as xm
inputs = None
targets = None
else:
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
inputs = gpu_buffer[:B * T].view(B, T)
targets = gpu_buffer[B * T:].view(B, T)
while True:
for row_idx in range(B):
@ -156,8 +163,14 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
# Single HtoD copy into persistent GPU buffer and yield
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
yield inputs, targets, state_dict
if use_xla:
gpu_buffer = xm.send_cpu_data_to_device(cpu_buffer, device)
inputs = gpu_buffer[:B * T].view(B, T)
targets = gpu_buffer[B * T:].view(B, T)
yield inputs, targets, state_dict
else:
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
yield inputs, targets, state_dict
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
"""Helper that omits state_dict from yields."""

View File

@ -253,7 +253,10 @@ class GPT(nn.Module):
# calculate the rotation frequencies at each (time, channel) pair
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
if device.type == "mps":
cos, sin = cos.float(), sin.float() # avoid bf16/fp32 mixing issues on MPSGraph
else:
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
@ -391,7 +394,10 @@ class GPT(nn.Module):
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
if self.cos.device.type == "mps":
assert self.cos.dtype == torch.float32, "Rotary embeddings must be float32 on MPS"
else:
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length

View File

@ -4,6 +4,7 @@ A number of functions that help with evaluating a base model.
import math
import torch
import torch.distributed as dist
from nanochat.common import get_dist_info
@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
@ -52,10 +53,16 @@ def evaluate_bpb(model, batches, steps, token_bytes):
total_nats += (loss2d * (num_bytes2d > 0)).sum()
total_bytes += num_bytes2d.sum()
# sum reduce across all ranks
world_size = dist.get_world_size() if dist.is_initialized() else 1
_, _, _, world_size = get_dist_info()
if world_size > 1:
dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
if dist.is_initialized():
dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
else:
# TPU/XLA path (multi-process without torch.distributed process group)
import torch_xla.core.xla_model as xm
total_nats = xm.all_reduce(xm.REDUCE_SUM, total_nats)
total_bytes = xm.all_reduce(xm.REDUCE_SUM, total_bytes)
# move both to cpu, calculate bpb and return
total_nats = total_nats.item()
total_bytes = total_bytes.item()

View File

@ -48,6 +48,34 @@ def adamw_step_fused(
step_size = lr_t / bias1
p.add_(exp_avg / denom, alpha=-step_size)
def adamw_step_eager(
    p: Tensor,
    grad: Tensor,
    exp_avg: Tensor,
    exp_avg_sq: Tensor,
    step_t: Tensor,
    lr_t: Tensor,
    beta1_t: Tensor,
    beta2_t: Tensor,
    eps_t: Tensor,
    wd_t: Tensor,
) -> None:
    """Eager (non-torch.compile) AdamW update, same math as adamw_step_fused.

    Mutates `p`, `exp_avg`, and `exp_avg_sq` in place; hyperparameters arrive
    as 0-dim tensors that may live on a different device than `p`.
    """
    # Bring the scalar hyperparameter tensors onto the parameter's device
    dev = p.device
    step_t, lr_t = step_t.to(device=dev), lr_t.to(device=dev)
    beta1_t, beta2_t = beta1_t.to(device=dev), beta2_t.to(device=dev)
    eps_t, wd_t = eps_t.to(device=dev), wd_t.to(device=dev)
    # Decoupled weight decay applied directly to the parameter
    p.mul_(1 - lr_t * wd_t)
    # Exponential moving averages of the gradient and its square
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    # Bias-corrected update
    correction1 = 1 - beta1_t ** step_t
    correction2 = 1 - beta2_t ** step_t
    denom = (exp_avg_sq / correction2).sqrt() + eps_t
    step_size = lr_t / correction1
    p.add_(exp_avg / denom, alpha=-step_size)
# -----------------------------------------------------------------------------
"""
Muon optimizer adapted and simplified from modded-nanogpt.
@ -108,7 +136,12 @@ def muon_step_fused(
g = stacked_grads.lerp_(momentum_buffer, momentum)
# Polar express
X = g.bfloat16()
# MPS has limited/fragile bfloat16 support and will assert if BF16 and FP32 mix
# inside MPSGraph. Keep Muon math in FP32 on MPS for stability.
if stacked_grads.device.type == "mps":
X = g.float()
else:
X = g.bfloat16()
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
if g.size(-2) > g.size(-1): # Tall matrix
for a, b, c in polar_express_coeffs[:ns_steps]:
@ -120,7 +153,7 @@ def muon_step_fused(
A = X @ X.mT
B = b * A + c * (A @ A)
X = a * X + B @ X
g = X
g = X.to(dtype=stacked_params.dtype)
# Variance reduction
beta2 = beta2_t.to(g.dtype)
@ -141,6 +174,58 @@ def muon_step_fused(
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
def muon_step_eager(
    stacked_grads: Tensor,
    stacked_params: Tensor,
    momentum_buffer: Tensor,
    second_momentum_buffer: Tensor,
    momentum_t: Tensor,
    lr_t: Tensor,
    wd_t: Tensor,
    beta2_t: Tensor,
    ns_steps: int,
    red_dim: int,
) -> None:
    """Eager (non-torch.compile) Muon update, same math as muon_step_fused.

    Mutates stacked_params, momentum_buffer, and second_momentum_buffer in
    place. Hyperparameters arrive as 0-dim tensors that may live on CPU.
    """
    # Bring scalar hyperparameter tensors onto the gradients' device
    device = stacked_grads.device
    momentum_t = momentum_t.to(device=device)
    lr_t = lr_t.to(device=device)
    wd_t = wd_t.to(device=device)
    beta2_t = beta2_t.to(device=device)
    # Momentum accumulation
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization.
    # BUGFIX: mirror muon_step_fused's MPS handling. This eager path is the one
    # dispatched on MPS, where mixing BF16 and FP32 asserts inside MPSGraph,
    # so keep the math in FP32 there instead of unconditionally using bfloat16.
    if device.type == "mps":
        X = g.float()
    else:
        X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):  # tall matrix: iterate with X on the left
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:  # wide (or square) matrix: iterate with X on the right
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X.to(dtype=stacked_params.dtype)
    # Variance reduction: track a per-slice second moment and rescale g so the
    # overall norm is preserved after the per-slice normalization
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Parameter update with weight decay masked to where decay and update agree in sign
    lr = lr_t.to(g.dtype)
    wd = wd_t.to(g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
# -----------------------------------------------------------------------------
# Single GPU version of the MuonAdamW optimizer.
# Used mostly for reference, debugging and testing.
@ -216,11 +301,18 @@ class MuonAdamW(torch.optim.Optimizer):
self._adamw_wd_t.fill_(group['weight_decay'])
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
adamw_step_fused(
p, grad, exp_avg, exp_avg_sq,
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)
if p.device.type == "mps":
adamw_step_eager(
p, grad, exp_avg, exp_avg_sq,
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)
else:
adamw_step_fused(
p, grad, exp_avg, exp_avg_sq,
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)
def _step_muon(self, group: dict) -> None:
"""
@ -260,18 +352,32 @@ class MuonAdamW(torch.optim.Optimizer):
self._muon_wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
muon_step_fused(
stacked_grads,
stacked_params,
momentum_buffer,
second_momentum_buffer,
self._muon_momentum_t,
self._muon_lr_t,
self._muon_wd_t,
self._muon_beta2_t,
group["ns_steps"],
red_dim,
)
if device.type == "mps":
muon_step_eager(
stacked_grads,
stacked_params,
momentum_buffer,
second_momentum_buffer,
self._muon_momentum_t,
self._muon_lr_t,
self._muon_wd_t,
self._muon_beta2_t,
group["ns_steps"],
red_dim,
)
else:
muon_step_fused(
stacked_grads,
stacked_params,
momentum_buffer,
second_momentum_buffer,
self._muon_momentum_t,
self._muon_lr_t,
self._muon_wd_t,
self._muon_beta2_t,
group["ns_steps"],
red_dim,
)
# Copy back to original params
torch._foreach_copy_(params, list(stacked_params.unbind(0)))

View File

@ -17,12 +17,15 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import argparse
import time
from contextlib import nullcontext, contextmanager
import runpy
import wandb
import torch
try:
import wandb
except ModuleNotFoundError:
wandb = None
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
@ -38,7 +41,7 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps|xla (empty = autodetect)")
# FP8 training
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
@ -75,16 +78,66 @@ parser.add_argument("--sample-every", type=int, default=2000, help="sample from
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
parser.add_argument("--smoke-test", action="store_true", help="run a tiny fast configuration for debugging (overrides several training args)")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
# TPU/XLA launcher: XLA uses multi-process execution via xmp.spawn rather than torchrun.
# When device_type is xla and we're not already in a spawned worker, spawn and re-run this module.
if device_type == "xla" and os.environ.get("NANOCHAT_XLA_SPAWNED", "0") != "1":
try:
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm
except Exception as e:
raise RuntimeError("--device-type xla requested but torch_xla is not available") from e
def _xla_worker(_index, argv):
    """Entry point for each xmp.spawn'd XLA worker process.

    Marks the process as already spawned (so the module-level launcher guard
    does not recurse), restores argv, then re-executes this training script
    as __main__ inside the worker.
    """
    os.environ["NANOCHAT_XLA_SPAWNED"] = "1"
    # Ensure child sees the same argv (xmp.spawn does not preserve it by default in all envs)
    import sys
    sys.argv = argv
    runpy.run_module("scripts.base_train", run_name="__main__")
xmp.spawn(_xla_worker, args=(os.sys.argv,), nprocs=xm.xrt_world_size())
raise SystemExit(0)
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
if args.smoke_test:
args.run = "dummy"
args.depth = 2
args.aspect_ratio = 32
args.head_dim = 64
args.max_seq_len = 128
args.device_batch_size = 1
args.total_batch_size = args.device_batch_size * args.max_seq_len * ddp_world_size
args.num_iterations = 5
args.target_flops = -1.0
args.target_param_data_ratio = -1
args.eval_every = -1
args.core_metric_every = -1
args.sample_every = -1
args.save_every = -1
args.resume_from_step = -1
args.fp8 = False
user_config = vars(args).copy() # refresh for logging
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
if device_type == "cuda":
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
elif device_type == "xla":
from torch_xla.amp import autocast as xla_autocast
autocast_ctx = xla_autocast(dtype=torch.bfloat16)
else:
autocast_ctx = nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else (lambda: None)
if device_type == "xla":
import torch_xla.core.xla_model as xm
synchronize = xm.mark_step
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
gpu_device_name = torch.cuda.get_device_name(0)
@ -95,7 +148,7 @@ else:
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
wandb_run = DummyWandb() if (use_dummy_wandb or wandb is None) else wandb.init(project="nanochat", name=args.run, config=user_config)
# Flash Attention status
if HAS_FA3:
@ -110,10 +163,16 @@ else:
print0("!" * 80)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
if args.smoke_test:
tokenizer = None
vocab_size = 8192
token_bytes = torch.ones(vocab_size, dtype=torch.int64, device=device)
print0(f"Vocab size: {vocab_size:,}")
else:
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
@ -291,7 +350,12 @@ def disable_fp8(model):
# Compile the model
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
if device_type == "xla" and os.environ.get("NANOCHAT_XLA_TORCH_COMPILE", "0") != "1":
model = model
elif device_type == "mps" and os.environ.get("NANOCHAT_MPS_TORCH_COMPILE", "0") != "1":
model = model
else:
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
# -----------------------------------------------------------------------------
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
@ -312,9 +376,23 @@ if resuming:
# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
if args.smoke_test:
# Synthetic data path for local debugging: avoids requiring pyarrow + dataset parquet files.
vocab_size_for_smoke = vocab_size
def _synthetic_loader():
    """Yield endless (inputs, targets, state_dict) batches of uniform-random tokens.

    Stands in for the real dataloader during --smoke-test so no dataset
    parquet files (or pyarrow) are needed. Batches are int64 tensors of
    shape (device_batch_size, max_seq_len) allocated directly on `device`.
    NOTE(review): relies on the enclosing script scope for
    vocab_size_for_smoke, args, and device (closure variables).
    """
    while True:
        x = torch.randint(0, vocab_size_for_smoke, (args.device_batch_size, args.max_seq_len), device=device, dtype=torch.long)
        y = torch.randint(0, vocab_size_for_smoke, (args.device_batch_size, args.max_seq_len), device=device, dtype=torch.long)
        # Dummy dataloader state so downstream checkpoint plumbing keeps working
        state_dict = {"pq_idx": 0, "rg_idx": 0, "epoch": 1}
        yield x, y, state_dict
train_loader = _synthetic_loader()
build_val_loader = lambda: iter(())
x, y, dataloader_state_dict = next(train_loader)
else:
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
@ -421,7 +499,13 @@ while True:
model.train()
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
if (not args.smoke_test) and (last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0)):
if ddp:
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.barrier()
elif device_type == "xla":
import torch_xla.core.xla_model as xm
xm.rendezvous("nanochat_base_train_checkpoint")
save_checkpoint(
checkpoint_dir,
step,
@ -443,6 +527,12 @@ while True:
},
rank=ddp_rank,
)
if ddp:
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.barrier()
elif device_type == "xla":
import torch_xla.core.xla_model as xm
xm.rendezvous("nanochat_base_train_checkpoint_done")
# termination conditions (TODO: possibly also add loss explosions etc.)
if last_step:
@ -533,30 +623,31 @@ if val_bpb is not None:
print0(f"Minimum validation bpb: {min_val_bpb:.6f}")
# Log to report
from nanochat.report import get_report
get_report().log(section="Base model training", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of parameters": num_params,
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
"DDP world size": ddp_world_size,
"warmup_ratio": args.warmup_ratio,
"warmdown_ratio": args.warmdown_ratio,
"final_lr_frac": args.final_lr_frac,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
"Final validation bpb": val_bpb,
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
])
if not args.smoke_test:
from nanochat.report import get_report
get_report().log(section="Base model training", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of parameters": num_params,
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
"DDP world size": ddp_world_size,
"warmup_ratio": args.warmup_ratio,
"warmdown_ratio": args.warmdown_ratio,
"final_lr_frac": args.final_lr_frac,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
"Final validation bpb": val_bpb,
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
])
# cleanup
wandb_run.finish() # wandb run finish

View File

@ -18,7 +18,7 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps', 'xla'], help='Device type for evaluation: cuda|cpu|mps|xla. empty => autodetect')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
args = parser.parse_args()
@ -27,7 +27,13 @@ args = parser.parse_args()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
if device_type == "cuda":
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
elif device_type == "xla":
from torch_xla.amp import autocast as xla_autocast
autocast_ctx = xla_autocast(dtype=ptdtype)
else:
autocast_ctx = nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
# Special tokens for the chat state machine

View File

@ -71,8 +71,13 @@ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_
if ddp:
num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
total_tensor = torch.tensor([total], dtype=torch.long, device=device)
dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
if dist.is_initialized():
dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
else:
import torch_xla.core.xla_model as xm
num_passed_tensor = xm.all_reduce(xm.REDUCE_SUM, num_passed_tensor)
total_tensor = xm.all_reduce(xm.REDUCE_SUM, total_tensor)
num_passed = num_passed_tensor.item()
total = total_tensor.item()
@ -145,8 +150,13 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
if ddp:
num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
total_tensor = torch.tensor([total], dtype=torch.long, device=device)
dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
if dist.is_initialized():
dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
else:
import torch_xla.core.xla_model as xm
num_passed_tensor = xm.all_reduce(xm.REDUCE_SUM, num_passed_tensor)
total_tensor = xm.all_reduce(xm.REDUCE_SUM, total_tensor)
num_passed = num_passed_tensor.item()
total = total_tensor.item()
@ -194,13 +204,19 @@ if __name__ == "__main__":
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps', 'xla'], help='Device type for evaluation: cuda|cpu|mps|xla. empty => autodetect')
args = parser.parse_args()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
if device_type == "cuda":
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
elif device_type == "xla":
from torch_xla.amp import autocast as xla_autocast
autocast_ctx = xla_autocast(dtype=ptdtype)
else:
autocast_ctx = nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer)

View File

@ -13,7 +13,7 @@ import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import runpy
import torch
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
@ -30,13 +30,18 @@ from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
try:
import wandb
except ModuleNotFoundError:
wandb = None
# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps|xla (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
@ -58,22 +63,63 @@ parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bp
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
parser.add_argument("--smoke-test", action="store_true", help="run a tiny fast configuration for debugging (overrides several training args)")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
# TPU/XLA launcher: XLA uses multi-process execution via xmp.spawn rather than torchrun
# When device_type is xla and we're not already in a spawned worker, spawn and re-run this module
if device_type == "xla" and os.environ.get("NANOCHAT_XLA_SPAWNED", "0") != "1":
try:
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm
except Exception as e:
raise RuntimeError("--device-type xla requested but torch_xla is not available") from e
def _xla_worker(_index, argv):
    """Entry point for each xmp.spawn'd XLA worker process.

    Marks the process as already spawned (so the module-level launcher guard
    does not recurse), restores argv, then re-executes this SFT script as
    __main__ inside the worker.
    """
    os.environ["NANOCHAT_XLA_SPAWNED"] = "1"
    # Restore argv: xmp.spawn does not preserve it by default in all envs
    import sys
    sys.argv = argv
    runpy.run_module("scripts.chat_sft", run_name="__main__")
xmp.spawn(_xla_worker, args=(os.sys.argv,), nprocs=xm.xrt_world_size())
raise SystemExit(0)
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
if args.smoke_test:
args.run = "dummy"
args.dry_run = True
args.max_seq_len = 128
args.device_batch_size = 1
args.total_batch_size = args.device_batch_size * args.max_seq_len * ddp_world_size
args.num_iterations = 5
args.eval_every = -1
args.eval_tokens = args.total_batch_size
user_config = vars(args).copy() # refresh for logging
master_process = ddp_rank == 0
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
# Mixed-precision autocast context per backend; cpu/mps fall back to a no-op.
# (The old single-line cuda-or-nullcontext assignments that preceded this
# branching were dead code — immediately shadowed — and have been removed.)
if device_type == "cuda":
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
elif device_type == "xla":
    from torch_xla.amp import autocast as xla_autocast
    # NOTE(review): some torch_xla versions take the XLA device as the first
    # positional argument of autocast — confirm against the pinned torch_xla.
    autocast_ctx = xla_autocast(dtype=ptdtype)
else:
    autocast_ctx = nullcontext()
# synchronize flushes pending device work (used around timing). On XLA,
# mark_step is the closest analogue: it cuts and executes the lazy graph.
if device_type == "cuda":
    synchronize = torch.cuda.synchronize
elif device_type == "xla":
    import torch_xla.core.xla_model as xm
    synchronize = xm.mark_step
else:
    synchronize = lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
# wandb may be None if its import failed; fall back to the no-op logger then.
# Keep a single assignment here: a duplicated earlier assignment without the
# `wandb is None` guard would call wandb.init twice (or crash) before being
# overwritten, so only this guarded form is kept.
wandb_run = DummyWandb() if (use_dummy_wandb or wandb is None) else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
@ -81,7 +127,10 @@ pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model  # keep a handle to the uncompiled module — presumably for checkpointing; verify against save path
# Skip torch.compile on XLA by default: XLA already traces/compiles its own
# graphs, so stacking torch.compile on top is opt-in via NANOCHAT_XLA_TORCH_COMPILE=1.
# (Replaces the previous `model = model` no-op branch.)
if device_type != "xla" or os.environ.get("NANOCHAT_XLA_TORCH_COMPILE", "0") == "1":
    model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
@ -261,8 +310,13 @@ while True:
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
if ddp:
    last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
    if dist.is_initialized():
        # torchrun / torch.distributed path
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
    else:
        # XLA path: ddp is reported true but torch.distributed is never
        # initialized there, so use the xla_model collective instead.
        # (The old unconditional dist.all_reduce lines left above this branch
        # were diff residue and have been removed — they would crash on XLA.)
        import torch_xla.core.xla_model as xm
        last_step_tensor = xm.all_reduce(xm.REDUCE_MAX, last_step_tensor)
    last_step = bool(last_step_tensor.item())
# once in a while: evaluate the val bpb (all ranks participate)
if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
@ -286,6 +340,12 @@ while True:
# Master rank writes the SFT checkpoint on the final step (unless --dry-run).
if master_process and last_step and not args.dry_run:
    output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
    checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
    if ddp:
        if dist.is_initialized():
            dist.barrier()
        elif device_type == "xla":
            import torch_xla.core.xla_model as xm
            # NOTE(review): as written this barrier/rendezvous sits inside the
            # master_process-only branch; if non-master ranks never reach a
            # matching collective, every multi-rank run deadlocks here —
            # confirm the intended nesting against the non-master code path.
            xm.rendezvous("nanochat_chat_sft_checkpoint")
save_checkpoint(
checkpoint_dir,
step,
@ -306,6 +366,12 @@ while True:
"user_config": user_config, # inputs to the training script
}
)
# Post-save synchronization so other ranks don't race ahead of the checkpoint write.
if ddp:
    if dist.is_initialized():
        dist.barrier()
    elif device_type == "xla":
        import torch_xla.core.xla_model as xm
        # NOTE(review): confirm ALL ranks execute this block (i.e. it is not
        # nested under a master_process-only branch) — a rank-0-only
        # rendezvous/barrier would deadlock the collective.
        xm.rendezvous("nanochat_chat_sft_checkpoint_done")
if last_step:
    break

View File

@ -70,7 +70,7 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
# Register --device-type exactly once: argparse raises ArgumentError if the same
# option string is added twice, and the superseded (pre-xla) registration that
# preceded this line has been dropped in favor of the xla-aware one.
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps', 'xla'], help='Device type for evaluation: cuda|cpu|mps|xla. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
@ -103,7 +103,7 @@ class WorkerPool:
if device_type == "cuda":
num_gpus = torch.cuda.device_count()
else:
num_gpus = 1 # e.g. cpu|mps
num_gpus = 1 # e.g. cpu|mps|xla
self.num_gpus = num_gpus
self.workers: List[Worker] = []
self.available_workers: asyncio.Queue = asyncio.Queue()
@ -119,13 +119,23 @@ class WorkerPool:
if device_type == "cuda":
device = torch.device(f"cuda:{gpu_id}")
print(f"Loading model on GPU {gpu_id}...")
elif device_type == "xla":
import torch_xla.core.xla_model as xm
device = xm.xla_device()
print(f"Loading model on {device_type}...")
else:
device = torch.device(device_type) # e.g. cpu|mps
print(f"Loading model on {device_type}...")
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
if device_type == "cuda":
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
elif device_type == "xla":
from torch_xla.amp import autocast as xla_autocast
autocast_ctx = xla_autocast(dtype=ptdtype)
else:
autocast_ctx = nullcontext()
worker = Worker(
gpu_id=gpu_id,