mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-06 07:35:32 +00:00
Merge c96dfaa97a into 1144d186ed
This commit is contained in:
commit
c9e377f7a4
|
|
@ -96,6 +96,12 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
|
|||
|
||||
def print0(s="", **kwargs):
    """Print `s` on the rank-0 process only; other ranks stay silent.

    The rank comes from the RANK environment variable when it is set.
    When it is not set, the torch_xla ordinal is tried instead, and if that
    lookup fails for any reason the process is treated as rank 0.
    Extra keyword arguments are forwarded to the builtin print().
    """
    rank_env = os.environ.get('RANK')
    if rank_env is not None:
        rank = int(rank_env)
    else:
        try:
            import torch_xla.core.xla_model as xm
            rank = int(xm.get_ordinal())
        except Exception:
            # best effort: the import or the ordinal lookup failed, so act as rank 0
            rank = 0
    if rank != 0:
        return
    print(s, **kwargs)
|
||||
|
||||
|
|
@ -136,8 +142,20 @@ def get_dist_info():
|
|||
ddp_local_rank = int(os.environ['LOCAL_RANK'])
|
||||
ddp_world_size = int(os.environ['WORLD_SIZE'])
|
||||
return True, ddp_rank, ddp_local_rank, ddp_world_size
|
||||
else:
|
||||
return False, 0, 0, 1
|
||||
|
||||
# TPU/XLA uses torch_xla multiprocessing rather than torchrun.
|
||||
# If we're in an XLA runtime, expose rank/world size so dataloaders shard correctly.
|
||||
try:
|
||||
import torch_xla.core.xla_model as xm
|
||||
ddp_world_size = int(xm.xrt_world_size())
|
||||
if ddp_world_size > 1:
|
||||
ddp_rank = int(xm.get_ordinal())
|
||||
ddp_local_rank = int(xm.get_local_ordinal())
|
||||
return True, ddp_rank, ddp_local_rank, ddp_world_size
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False, 0, 0, 1
|
||||
|
||||
def autodetect_device_type():
|
||||
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
|
||||
|
|
@ -150,14 +168,19 @@ def autodetect_device_type():
|
|||
print0(f"Autodetected device type: {device_type}")
|
||||
return device_type
|
||||
|
||||
def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
def compute_init(device_type="cuda"): # cuda|cpu|mps|xla
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
|
||||
assert device_type in ["cuda", "mps", "cpu", "xla"], "Invalid device type atm"
|
||||
if device_type == "cuda":
|
||||
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
|
||||
if device_type == "mps":
|
||||
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
||||
if device_type == "xla":
|
||||
try:
|
||||
import torch_xla.core.xla_model as xm
|
||||
except Exception as e:
|
||||
raise AssertionError("device_type is 'xla' but torch_xla is not available") from e
|
||||
|
||||
# Reproducibility
|
||||
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
|
||||
|
|
@ -179,6 +202,9 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
|
|||
torch.cuda.set_device(device) # make "cuda" default to this device
|
||||
dist.init_process_group(backend="nccl", device_id=device)
|
||||
dist.barrier()
|
||||
elif device_type == "xla":
|
||||
import torch_xla.core.xla_model as xm
|
||||
device = xm.xla_device()
|
||||
else:
|
||||
device = torch.device(device_type) # mps|cpu
|
||||
|
||||
|
|
|
|||
|
|
@ -109,14 +109,21 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
|||
|
||||
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
|
||||
# This gives us contiguous views and a single HtoD transfer
|
||||
use_cuda = device == "cuda"
|
||||
device_str = str(device)
|
||||
use_cuda = device == "cuda" or device_str == "cuda"
|
||||
use_xla = device_str.startswith("xla")
|
||||
row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists
|
||||
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
|
||||
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
|
||||
cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
|
||||
cpu_targets = cpu_buffer[B * T:].view(B, T)
|
||||
inputs = gpu_buffer[:B * T].view(B, T)
|
||||
targets = gpu_buffer[B * T:].view(B, T)
|
||||
if use_xla:
|
||||
import torch_xla.core.xla_model as xm
|
||||
inputs = None
|
||||
targets = None
|
||||
else:
|
||||
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
|
||||
inputs = gpu_buffer[:B * T].view(B, T)
|
||||
targets = gpu_buffer[B * T:].view(B, T)
|
||||
|
||||
while True:
|
||||
for row_idx in range(B):
|
||||
|
|
@ -156,8 +163,14 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
|||
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
# Single HtoD copy into persistent GPU buffer and yield
|
||||
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
|
||||
yield inputs, targets, state_dict
|
||||
if use_xla:
|
||||
gpu_buffer = xm.send_cpu_data_to_device(cpu_buffer, device)
|
||||
inputs = gpu_buffer[:B * T].view(B, T)
|
||||
targets = gpu_buffer[B * T:].view(B, T)
|
||||
yield inputs, targets, state_dict
|
||||
else:
|
||||
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
|
||||
yield inputs, targets, state_dict
|
||||
|
||||
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
|
||||
"""Helper that omits state_dict from yields."""
|
||||
|
|
|
|||
|
|
@ -253,7 +253,10 @@ class GPT(nn.Module):
|
|||
# calculate the rotation frequencies at each (time, channel) pair
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos, sin = freqs.cos(), freqs.sin()
|
||||
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
|
||||
if device.type == "mps":
|
||||
cos, sin = cos.float(), sin.float() # avoid bf16/fp32 mixing issues on MPSGraph
|
||||
else:
|
||||
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
|
||||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||
return cos, sin
|
||||
|
||||
|
|
@ -391,7 +394,10 @@ class GPT(nn.Module):
|
|||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||
if self.cos.device.type == "mps":
|
||||
assert self.cos.dtype == torch.float32, "Rotary embeddings must be float32 on MPS"
|
||||
else:
|
||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ A number of functions that help with evaluating a base model.
|
|||
import math
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from nanochat.common import get_dist_info
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_bpb(model, batches, steps, token_bytes):
|
||||
|
|
@ -52,10 +53,16 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
total_nats += (loss2d * (num_bytes2d > 0)).sum()
|
||||
total_bytes += num_bytes2d.sum()
|
||||
# sum reduce across all ranks
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
_, _, _, world_size = get_dist_info()
|
||||
if world_size > 1:
|
||||
dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
|
||||
else:
|
||||
# TPU/XLA path (multi-process without torch.distributed process group)
|
||||
import torch_xla.core.xla_model as xm
|
||||
total_nats = xm.all_reduce(xm.REDUCE_SUM, total_nats)
|
||||
total_bytes = xm.all_reduce(xm.REDUCE_SUM, total_bytes)
|
||||
# move both to cpu, calculate bpb and return
|
||||
total_nats = total_nats.item()
|
||||
total_bytes = total_bytes.item()
|
||||
|
|
|
|||
|
|
@ -48,6 +48,34 @@ def adamw_step_fused(
|
|||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
|
||||
|
||||
def adamw_step_eager(
    p: Tensor,
    grad: Tensor,
    exp_avg: Tensor,
    exp_avg_sq: Tensor,
    step_t: Tensor,
    lr_t: Tensor,
    beta1_t: Tensor,
    beta2_t: Tensor,
    eps_t: Tensor,
    wd_t: Tensor,
) -> None:
    """AdamW update without torch.compile (same math as adamw_step_fused).

    Mutates `p`, `exp_avg` and `exp_avg_sq` in place and returns nothing.
    The hyperparameters arrive as tensors and are first moved onto the
    device of `p`.
    """
    # Put every hyperparameter tensor on the parameter's device.
    target = p.device
    step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t = (
        t.to(device=target)
        for t in (step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t)
    )

    # Weight decay is applied directly to the parameter, before the moments.
    p.mul_(1 - lr_t * wd_t)

    # Running averages of the gradient and of its square.
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)

    # Bias corrections for the current step count.
    correction1 = 1 - beta1_t ** step_t
    correction2 = 1 - beta2_t ** step_t

    rms = (exp_avg_sq / correction2).sqrt() + eps_t
    p.add_(exp_avg / rms, alpha=-(lr_t / correction1))
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
"""
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
|
|
@ -112,7 +140,12 @@ def muon_step_fused(
|
|||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
# MPS has limited/fragile bfloat16 support and will assert if BF16 and FP32 mix
|
||||
# inside MPSGraph. Keep Muon math in FP32 on MPS for stability.
|
||||
if stacked_grads.device.type == "mps":
|
||||
X = g.float()
|
||||
else:
|
||||
X = g.bfloat16()
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
if g.size(-2) > g.size(-1): # Tall matrix
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
|
|
@ -124,7 +157,7 @@ def muon_step_fused(
|
|||
A = X @ X.mT
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
g = X
|
||||
g = X.to(dtype=stacked_params.dtype)
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
|
|
@ -145,6 +178,58 @@ def muon_step_fused(
|
|||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
|
||||
def muon_step_eager(
    stacked_grads: Tensor,
    stacked_params: Tensor,
    momentum_buffer: Tensor,
    second_momentum_buffer: Tensor,
    momentum_t: Tensor,
    lr_t: Tensor,
    wd_t: Tensor,
    beta2_t: Tensor,
    ns_steps: int,
    red_dim: int,
) -> None:
    """Muon update without torch.compile (same math as muon_step_fused).

    Everything is updated in place and nothing is returned:
    `momentum_buffer` and `stacked_grads` by the momentum interpolation,
    `second_momentum_buffer` by the variance estimate, and `stacked_params`
    by the final step. The hyperparameter tensors are first moved onto the
    device of `stacked_grads`.

    Args:
        stacked_grads: Gradients; the last two dims are the matrix dims.
            Overwritten in place.
        stacked_params: Parameters to update in place.
        momentum_buffer: First-moment state, updated in place.
        second_momentum_buffer: Second-moment state, updated in place.
        momentum_t, lr_t, wd_t, beta2_t: Hyperparameters passed as tensors.
        ns_steps: How many leading entries of `polar_express_coeffs` to apply.
        red_dim: Dimension over which the squared update is averaged.
    """
    device = stacked_grads.device
    momentum_t = momentum_t.to(device=device)
    lr_t = lr_t.to(device=device)
    wd_t = wd_t.to(device=device)
    beta2_t = beta2_t.to(device=device)

    # Momentum: fold the gradient into the buffer, then blend the buffer back
    # into the gradient (both in place).
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)

    # Polar express
    # Use the same dtype choice as muon_step_fused: MPS asserts when BF16 and
    # FP32 mix inside MPSGraph, so the iteration stays in FP32 there. This
    # eager path is the one used on MPS, so it must not force bfloat16.
    if device.type == "mps":
        X = g.float()
    else:
        X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):  # Tall matrix
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X.to(dtype=stacked_params.dtype)

    # Variance reduction: rescale by the running second moment while keeping
    # the overall norm of the update (v_norm) unchanged.
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)

    # Parameter update; weight decay is applied only where the update and the
    # parameter share a sign (the mask).
    lr = lr_t.to(g.dtype)
    wd = wd_t.to(g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Single GPU version of the MuonAdamW optimizer.
|
||||
# Used mostly for reference, debugging and testing.
|
||||
|
|
@ -220,11 +305,18 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p, grad, exp_avg, exp_avg_sq,
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
||||
)
|
||||
if p.device.type == "mps":
|
||||
adamw_step_eager(
|
||||
p, grad, exp_avg, exp_avg_sq,
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
||||
)
|
||||
else:
|
||||
adamw_step_fused(
|
||||
p, grad, exp_avg, exp_avg_sq,
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
||||
)
|
||||
|
||||
def _step_muon(self, group: dict) -> None:
|
||||
"""
|
||||
|
|
@ -264,18 +356,32 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
self._muon_wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._muon_momentum_t,
|
||||
self._muon_lr_t,
|
||||
self._muon_wd_t,
|
||||
self._muon_beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
if device.type == "mps":
|
||||
muon_step_eager(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._muon_momentum_t,
|
||||
self._muon_lr_t,
|
||||
self._muon_wd_t,
|
||||
self._muon_beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
else:
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._muon_momentum_t,
|
||||
self._muon_lr_t,
|
||||
self._muon_wd_t,
|
||||
self._muon_beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
|
|
|||
|
|
@ -17,12 +17,15 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
|||
import argparse
|
||||
import time
|
||||
from contextlib import nullcontext, contextmanager
|
||||
import runpy
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
try:
|
||||
import wandb
|
||||
except ModuleNotFoundError:
|
||||
wandb = None
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
|
|
@ -38,7 +41,7 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
|
|||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps|xla (empty = autodetect)")
|
||||
# FP8 training
|
||||
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
|
||||
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
|
||||
|
|
@ -75,16 +78,66 @@ parser.add_argument("--sample-every", type=int, default=2000, help="sample from
|
|||
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
|
||||
# Output
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
parser.add_argument("--smoke-test", action="store_true", help="run a tiny fast configuration for debugging (overrides several training args)")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
|
||||
# TPU/XLA launcher: XLA uses multi-process execution via xmp.spawn rather than torchrun.
|
||||
# When device_type is xla and we're not already in a spawned worker, spawn and re-run this module.
|
||||
if device_type == "xla" and os.environ.get("NANOCHAT_XLA_SPAWNED", "0") != "1":
    # Parent-side launcher: start the XLA worker processes, each of which
    # re-runs this module, then exit the parent.
    import sys  # use sys.argv directly instead of reaching through os.sys

    try:
        import torch_xla.distributed.xla_multiprocessing as xmp
    except Exception as e:
        raise RuntimeError("--device-type xla requested but torch_xla is not available") from e

    def _xla_worker(_index, argv):
        """Entry point for one spawned worker: mark it as spawned and re-run the script."""
        os.environ["NANOCHAT_XLA_SPAWNED"] = "1"
        # Ensure child sees the same argv (xmp.spawn does not preserve it by default in all envs)
        sys.argv = argv
        runpy.run_module("scripts.base_train", run_name="__main__")

    # nprocs=None lets torch_xla use every available device. The world size
    # queried here in the un-spawned parent would be 1, which would limit the
    # run to a single device.
    # NOTE(review): with the "spawn" start method each child re-imports
    # __main__; confirm this guard still behaves as intended in that case.
    xmp.spawn(_xla_worker, args=(sys.argv,), nprocs=None)
    raise SystemExit(0)
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
|
||||
if args.smoke_test:
|
||||
args.run = "dummy"
|
||||
args.depth = 2
|
||||
args.aspect_ratio = 32
|
||||
args.head_dim = 64
|
||||
args.max_seq_len = 128
|
||||
args.device_batch_size = 1
|
||||
args.total_batch_size = args.device_batch_size * args.max_seq_len * ddp_world_size
|
||||
args.num_iterations = 5
|
||||
args.target_flops = -1.0
|
||||
args.target_param_data_ratio = -1
|
||||
args.eval_every = -1
|
||||
args.core_metric_every = -1
|
||||
args.sample_every = -1
|
||||
args.save_every = -1
|
||||
args.resume_from_step = -1
|
||||
args.fp8 = False
|
||||
user_config = vars(args).copy() # refresh for logging
|
||||
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
if device_type == "cuda":
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
|
||||
elif device_type == "xla":
|
||||
from torch_xla.amp import autocast as xla_autocast
|
||||
autocast_ctx = xla_autocast(dtype=torch.bfloat16)
|
||||
else:
|
||||
autocast_ctx = nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else (lambda: None)
|
||||
if device_type == "xla":
|
||||
import torch_xla.core.xla_model as xm
|
||||
synchronize = xm.mark_step
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
if device_type == "cuda":
|
||||
gpu_device_name = torch.cuda.get_device_name(0)
|
||||
|
|
@ -95,7 +148,7 @@ else:
|
|||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
wandb_run = DummyWandb() if (use_dummy_wandb or wandb is None) else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if HAS_FA3:
|
||||
|
|
@ -110,10 +163,16 @@ else:
|
|||
print0("!" * 80)
|
||||
|
||||
# Tokenizer will be useful for evaluation, also we need the vocab size
|
||||
tokenizer = get_tokenizer()
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
print0(f"Vocab size: {vocab_size:,}")
|
||||
if args.smoke_test:
|
||||
tokenizer = None
|
||||
vocab_size = 8192
|
||||
token_bytes = torch.ones(vocab_size, dtype=torch.int64, device=device)
|
||||
print0(f"Vocab size: {vocab_size:,}")
|
||||
else:
|
||||
tokenizer = get_tokenizer()
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
print0(f"Vocab size: {vocab_size:,}")
|
||||
|
||||
# Model kwargs are derived from the desired depth of the model
|
||||
# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
|
||||
|
|
@ -291,7 +350,12 @@ def disable_fp8(model):
|
|||
# Compile the model
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
if device_type == "xla" and os.environ.get("NANOCHAT_XLA_TORCH_COMPILE", "0") != "1":
|
||||
model = model
|
||||
elif device_type == "mps" and os.environ.get("NANOCHAT_MPS_TORCH_COMPILE", "0") != "1":
|
||||
model = model
|
||||
else:
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
|
|
@ -312,9 +376,23 @@ if resuming:
|
|||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
if args.smoke_test:
|
||||
# Synthetic data path for local debugging: avoids requiring pyarrow + dataset parquet files.
|
||||
vocab_size_for_smoke = vocab_size
|
||||
def _synthetic_loader():
|
||||
while True:
|
||||
x = torch.randint(0, vocab_size_for_smoke, (args.device_batch_size, args.max_seq_len), device=device, dtype=torch.long)
|
||||
y = torch.randint(0, vocab_size_for_smoke, (args.device_batch_size, args.max_seq_len), device=device, dtype=torch.long)
|
||||
state_dict = {"pq_idx": 0, "rg_idx": 0, "epoch": 1}
|
||||
yield x, y, state_dict
|
||||
train_loader = _synthetic_loader()
|
||||
build_val_loader = lambda: iter(())
|
||||
x, y, dataloader_state_dict = next(train_loader)
|
||||
else:
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
|
@ -421,7 +499,13 @@ while True:
|
|||
model.train()
|
||||
|
||||
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
|
||||
if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
|
||||
if (not args.smoke_test) and (last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0)):
|
||||
if ddp:
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
torch.distributed.barrier()
|
||||
elif device_type == "xla":
|
||||
import torch_xla.core.xla_model as xm
|
||||
xm.rendezvous("nanochat_base_train_checkpoint")
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
|
|
@ -443,6 +527,12 @@ while True:
|
|||
},
|
||||
rank=ddp_rank,
|
||||
)
|
||||
if ddp:
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
torch.distributed.barrier()
|
||||
elif device_type == "xla":
|
||||
import torch_xla.core.xla_model as xm
|
||||
xm.rendezvous("nanochat_base_train_checkpoint_done")
|
||||
|
||||
# termination conditions (TODO: possibly also add loss explosions etc.)
|
||||
if last_step:
|
||||
|
|
@ -533,30 +623,31 @@ if val_bpb is not None:
|
|||
print0(f"Minimum validation bpb: {min_val_bpb:.6f}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model training", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of parameters": num_params,
|
||||
"Number of FLOPs per token": f"{num_flops_per_token:e}",
|
||||
"Calculated number of iterations": num_iterations,
|
||||
"Number of training tokens": total_tokens,
|
||||
"Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
|
||||
"DDP world size": ddp_world_size,
|
||||
"warmup_ratio": args.warmup_ratio,
|
||||
"warmdown_ratio": args.warmdown_ratio,
|
||||
"final_lr_frac": args.final_lr_frac,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
|
||||
"Final validation bpb": val_bpb,
|
||||
"CORE metric estimate": results.get("core_metric", None),
|
||||
"MFU %": f"{mfu:.2f}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time/60:.2f}m",
|
||||
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
|
||||
}
|
||||
])
|
||||
if not args.smoke_test:
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model training", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of parameters": num_params,
|
||||
"Number of FLOPs per token": f"{num_flops_per_token:e}",
|
||||
"Calculated number of iterations": num_iterations,
|
||||
"Number of training tokens": total_tokens,
|
||||
"Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
|
||||
"DDP world size": ddp_world_size,
|
||||
"warmup_ratio": args.warmup_ratio,
|
||||
"warmdown_ratio": args.warmdown_ratio,
|
||||
"final_lr_frac": args.final_lr_frac,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
|
||||
"Final validation bpb": val_bpb,
|
||||
"CORE metric estimate": results.get("core_metric", None),
|
||||
"MFU %": f"{mfu:.2f}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time/60:.2f}m",
|
||||
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
|
||||
}
|
||||
])
|
||||
|
||||
# cleanup
|
||||
wandb_run.finish() # wandb run finish
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
|||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps', 'xla'], help='Device type for evaluation: cuda|cpu|mps|xla. empty => autodetect')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -27,7 +27,13 @@ args = parser.parse_args()
|
|||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
if device_type == "cuda":
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||
elif device_type == "xla":
|
||||
from torch_xla.amp import autocast as xla_autocast
|
||||
autocast_ctx = xla_autocast(dtype=ptdtype)
|
||||
else:
|
||||
autocast_ctx = nullcontext()
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
|
||||
# Special tokens for the chat state machine
|
||||
|
|
|
|||
|
|
@ -71,8 +71,13 @@ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_
|
|||
if ddp:
|
||||
num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
|
||||
total_tensor = torch.tensor([total], dtype=torch.long, device=device)
|
||||
dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
|
||||
else:
|
||||
import torch_xla.core.xla_model as xm
|
||||
num_passed_tensor = xm.all_reduce(xm.REDUCE_SUM, num_passed_tensor)
|
||||
total_tensor = xm.all_reduce(xm.REDUCE_SUM, total_tensor)
|
||||
num_passed = num_passed_tensor.item()
|
||||
total = total_tensor.item()
|
||||
|
||||
|
|
@ -145,8 +150,13 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
|||
if ddp:
|
||||
num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
|
||||
total_tensor = torch.tensor([total], dtype=torch.long, device=device)
|
||||
dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
|
||||
else:
|
||||
import torch_xla.core.xla_model as xm
|
||||
num_passed_tensor = xm.all_reduce(xm.REDUCE_SUM, num_passed_tensor)
|
||||
total_tensor = xm.all_reduce(xm.REDUCE_SUM, total_tensor)
|
||||
num_passed = num_passed_tensor.item()
|
||||
total = total_tensor.item()
|
||||
|
||||
|
|
@ -194,13 +204,19 @@ if __name__ == "__main__":
|
|||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps', 'xla'], help='Device type for evaluation: cuda|cpu|mps|xla. empty => autodetect')
|
||||
args = parser.parse_args()
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
if device_type == "cuda":
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||
elif device_type == "xla":
|
||||
from torch_xla.amp import autocast as xla_autocast
|
||||
autocast_ctx = xla_autocast(dtype=ptdtype)
|
||||
else:
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import argparse
|
|||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
import wandb
|
||||
import runpy
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
|
||||
|
|
@ -30,13 +30,18 @@ from tasks.smoltalk import SmolTalk
|
|||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ModuleNotFoundError:
|
||||
wandb = None
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps|xla (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
|
|
@ -58,22 +63,63 @@ parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bp
|
|||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
# Output
|
||||
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
|
||||
parser.add_argument("--smoke-test", action="store_true", help="run a tiny fast configuration for debugging (overrides several training args)")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
|
||||
# TPU/XLA launcher: XLA uses multi-process execution via xmp.spawn rather than torchrun
|
||||
# When device_type is xla and we're not already in a spawned worker, spawn and re-run this module
|
||||
if device_type == "xla" and os.environ.get("NANOCHAT_XLA_SPAWNED", "0") != "1":
|
||||
try:
|
||||
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||
import torch_xla.core.xla_model as xm
|
||||
except Exception as e:
|
||||
raise RuntimeError("--device-type xla requested but torch_xla is not available") from e
|
||||
|
||||
def _xla_worker(_index, argv):
|
||||
os.environ["NANOCHAT_XLA_SPAWNED"] = "1"
|
||||
import sys
|
||||
sys.argv = argv
|
||||
runpy.run_module("scripts.chat_sft", run_name="__main__")
|
||||
|
||||
xmp.spawn(_xla_worker, args=(os.sys.argv,), nprocs=xm.xrt_world_size())
|
||||
raise SystemExit(0)
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
|
||||
if args.smoke_test:
|
||||
args.run = "dummy"
|
||||
args.dry_run = True
|
||||
args.max_seq_len = 128
|
||||
args.device_batch_size = 1
|
||||
args.total_batch_size = args.device_batch_size * args.max_seq_len * ddp_world_size
|
||||
args.num_iterations = 5
|
||||
args.eval_every = -1
|
||||
args.eval_tokens = args.total_batch_size
|
||||
user_config = vars(args).copy() # refresh for logging
|
||||
|
||||
master_process = ddp_rank == 0
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
if device_type == "cuda":
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||
elif device_type == "xla":
|
||||
from torch_xla.amp import autocast as xla_autocast
|
||||
autocast_ctx = xla_autocast(dtype=ptdtype)
|
||||
else:
|
||||
autocast_ctx = nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else (lambda: None)
|
||||
if device_type == "xla":
|
||||
import torch_xla.core.xla_model as xm
|
||||
synchronize = xm.mark_step
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
|
||||
wandb_run = DummyWandb() if (use_dummy_wandb or wandb is None) else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||
|
|
@ -81,7 +127,10 @@ pretrain_batch_size = meta.get("device_batch_size", None)
|
|||
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
if device_type == "xla" and os.environ.get("NANOCHAT_XLA_TORCH_COMPILE", "0") != "1":
|
||||
model = model
|
||||
else:
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
|
|
@ -261,8 +310,13 @@ while True:
|
|||
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
|
||||
if ddp:
|
||||
last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
|
||||
dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
|
||||
last_step = bool(last_step_tensor.item())
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
|
||||
last_step = bool(last_step_tensor.item())
|
||||
else:
|
||||
import torch_xla.core.xla_model as xm
|
||||
last_step_tensor = xm.all_reduce(xm.REDUCE_MAX, last_step_tensor)
|
||||
last_step = bool(last_step_tensor.item())
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
|
||||
|
|
@ -286,6 +340,12 @@ while True:
|
|||
if master_process and last_step and not args.dry_run:
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
|
||||
if ddp:
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
elif device_type == "xla":
|
||||
import torch_xla.core.xla_model as xm
|
||||
xm.rendezvous("nanochat_chat_sft_checkpoint")
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
|
|
@ -306,6 +366,12 @@ while True:
|
|||
"user_config": user_config, # inputs to the training script
|
||||
}
|
||||
)
|
||||
if ddp:
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
elif device_type == "xla":
|
||||
import torch_xla.core.xla_model as xm
|
||||
xm.rendezvous("nanochat_chat_sft_checkpoint_done")
|
||||
|
||||
if last_step:
|
||||
break
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
|
|||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps', 'xla'], help='Device type for evaluation: cuda|cpu|mps|xla. empty => autodetect')
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -103,7 +103,7 @@ class WorkerPool:
|
|||
if device_type == "cuda":
|
||||
num_gpus = torch.cuda.device_count()
|
||||
else:
|
||||
num_gpus = 1 # e.g. cpu|mps
|
||||
num_gpus = 1 # e.g. cpu|mps|xla
|
||||
self.num_gpus = num_gpus
|
||||
self.workers: List[Worker] = []
|
||||
self.available_workers: asyncio.Queue = asyncio.Queue()
|
||||
|
|
@ -119,13 +119,23 @@ class WorkerPool:
|
|||
if device_type == "cuda":
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
print(f"Loading model on GPU {gpu_id}...")
|
||||
elif device_type == "xla":
|
||||
import torch_xla.core.xla_model as xm
|
||||
device = xm.xla_device()
|
||||
print(f"Loading model on {device_type}...")
|
||||
else:
|
||||
device = torch.device(device_type) # e.g. cpu|mps
|
||||
print(f"Loading model on {device_type}...")
|
||||
|
||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
if device_type == "cuda":
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||
elif device_type == "xla":
|
||||
from torch_xla.amp import autocast as xla_autocast
|
||||
autocast_ctx = xla_autocast(dtype=ptdtype)
|
||||
else:
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
worker = Worker(
|
||||
gpu_id=gpu_id,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user