From fcc4de7b96ffdfdb411244f361dde0d83f3f9fda Mon Sep 17 00:00:00 2001 From: Sushrut Karnik Date: Thu, 12 Mar 2026 23:55:05 +0100 Subject: [PATCH 1/3] Change dtypes, device, remove bfloat16 in muon, remove torch.compile, mps base.train 100min runtime --- nanochat/common.py | 6 + nanochat/engine.py | 4 +- nanochat/flash_attention.py | 6 +- nanochat/gpt.py | 2 +- nanochat/optim.py | 52 ++++---- pyproject.toml | 12 +- runs/runcpu.sh | 16 ++- scripts/base_train.py | 9 +- uv.lock | 233 ++++++------------------------------ 9 files changed, 101 insertions(+), 239 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index bd14fd2..f7ace6e 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -27,6 +27,12 @@ def _detect_compute_dtype(): # fp16 training requires GradScaler (not yet implemented), so fall back to fp32. # Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing. return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)" + + if torch.backends.mps.is_available(): + # torch.float16 + # return torch.float16, "auto-detected: mps, float16" + return torch.float32, "auto-detected: mps, float32" + return torch.float32, "auto-detected: no CUDA (CPU/MPS)" COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype() diff --git a/nanochat/engine.py b/nanochat/engine.py index 4724c8f..c96a643 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -160,6 +160,8 @@ class RowState: self.python_expr_tokens = [] # Tokens of the current python expression self.completed = False # Whether this row has completed generation +from nanochat.common import COMPUTE_DTYPE + class Engine: def __init__(self, model, tokenizer): @@ -177,7 +179,7 @@ class Engine: # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption. # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase. 
# In particular, the KVCache should allocate its tensors lazily - dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 + dtype = COMPUTE_DTYPE #torch.bfloat16 if device.type == "cuda" else torch.float32 rng = torch.Generator(device=device) rng.manual_seed(seed) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index af2aee3..ea0e678 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -77,7 +77,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): # Full context, same length if (window < 0 or window >= Tq) and Tq == Tk: - return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) + return F.scaled_dot_product_attention(q, k, v, is_causal=True)#, enable_gqa=enable_gqa) # Single token generation if Tq == 1: @@ -86,7 +86,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): start = max(0, Tk - (window + 1)) k = k[:, :, start:, :] v = v[:, :, start:, :] - return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) + return F.scaled_dot_product_attention(q, k, v, is_causal=False)#, enable_gqa=enable_gqa) # Need explicit mask for sliding window/chunk inference device = q.device @@ -99,7 +99,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): if window >= 0 and window < Tk: mask = mask & ((row_idx - col_idx) <= window) - return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) + return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)#, enable_gqa=enable_gqa) # ============================================================================= # Public API: Same interface as FA3 diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 5e99c73..fb77d77 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -40,7 +40,7 @@ class GPTConfig: def norm(x): - return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok + return F.layer_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok 
class Linear(nn.Linear): """nn.Linear that casts weights to match input dtype in forward. diff --git a/nanochat/optim.py b/nanochat/optim.py index 0ee2e27..42b4949 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -17,7 +17,7 @@ Good old AdamW optimizer, fused kernel. https://arxiv.org/abs/1711.05101 """ -@torch.compile(dynamic=False, fullgraph=True) +# @torch.compile(dynamic=False, fullgraph=True) def adamw_step_fused( p: Tensor, # (32768, 768) - parameter tensor grad: Tensor, # (32768, 768) - gradient, same shape as p @@ -35,6 +35,7 @@ def adamw_step_fused( All in one compiled graph to eliminate Python overhead between ops. The 0-D CPU tensors avoid recompilation when hyperparameter values change. """ + p = p.to(grad.device) # Weight decay (decoupled, applied before the update) p.mul_(1 - lr_t * wd_t) # Update running averages (lerp_ is cleaner and fuses well) @@ -87,7 +88,7 @@ polar_express_coeffs = [ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), ] -@torch.compile(dynamic=False, fullgraph=True) +# @torch.compile(dynamic=False, fullgraph=True) def muon_step_fused( stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients stacked_params: Tensor, # (12, 768, 3072) - stacked parameters @@ -112,7 +113,7 @@ def muon_step_fused( g = stacked_grads.lerp_(momentum_buffer, momentum) # Polar express - X = g.bfloat16() + X = g.bfloat16() if torch.cuda.is_available() else g X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6) if g.size(-2) > g.size(-1): # Tall matrix for a, b, c in polar_express_coeffs[:ns_steps]: @@ -179,36 +180,39 @@ class MuonAdamW(torch.optim.Optimizer): super().__init__(param_groups, defaults={}) # 0-D CPU tensors to avoid torch.compile recompilation when values change # AdamW tensors - self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - 
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + device="mps" + self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device=device) + self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device) + self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device=device) + self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device=device) + self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device=device) + self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device=device) # Muon tensors - self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device=device) + self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device) + self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device=device) + self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device=device) def _step_adamw(self, group: dict) -> None: """ AdamW update for each param in the group individually. Lazy init the state, fill in all 0-D tensors, call the fused kernel. 
""" + + device = self._adamw_step_t.device for p in group['params']: if p.grad is None: continue - grad = p.grad + grad = p.grad.to(device) state = self.state[p] # State init if not state: state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p) - state['exp_avg_sq'] = torch.zeros_like(p) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] + state['exp_avg'] = torch.zeros_like(p).to(device) + state['exp_avg_sq'] = torch.zeros_like(p).to(device) + exp_avg = state['exp_avg'].to(device) + exp_avg_sq = state['exp_avg_sq'].to(device) state['step'] += 1 # Fill 0-D tensors with current values @@ -244,18 +248,18 @@ class MuonAdamW(torch.optim.Optimizer): # Momentum for every individual parameter if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device) - momentum_buffer = state["momentum_buffer"] + momentum_buffer = state["momentum_buffer"].to(self._muon_momentum_t.device) # Second momentum buffer is factored, either per-row or per-column if "second_momentum_buffer" not in state: state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1]) state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device) - second_momentum_buffer = state["second_momentum_buffer"] + second_momentum_buffer = state["second_momentum_buffer"].to(self._muon_momentum_t.device) red_dim = -1 if shape[-2] >= shape[-1] else -2 # Stack grads and params (NOTE: this assumes all params have the same shape) - stacked_grads = torch.stack([p.grad for p in params]) - stacked_params = torch.stack(params) + stacked_grads = torch.stack([p.grad for p in params]).to(self._muon_momentum_t.device) + stacked_params = torch.stack(params).to(self._muon_momentum_t.device) # Fill all the 0-D tensors with current values self._muon_momentum_t.fill_(group["momentum"]) @@ -278,7 +282,9 @@ class MuonAdamW(torch.optim.Optimizer): ) # Copy back to original params - torch._foreach_copy_(params, 
list(stacked_params.unbind(0))) + # torch._foreach_copy_(params, list(stacked_params.unbind(0))) + for p, sp in zip(params, stacked_params.unbind(0)): + p.copy_(sp) @torch.no_grad() def step(self): diff --git a/pyproject.toml b/pyproject.toml index 8b6fd95..870ab50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ dependencies = [ "tabulate>=0.9.0", "tiktoken>=0.11.0", "tokenizers>=0.22.0", - "torch==2.9.1", "transformers>=4.57.3", "uvicorn>=0.36.0", "wandb>=0.21.3", @@ -42,15 +41,7 @@ python_functions = ["test_*"] # target torch to cuda 12.8 or CPU [tool.uv.sources] -torch = [ - { index = "pytorch-cpu", extra = "cpu" }, - { index = "pytorch-cu128", extra = "gpu" }, -] -[[tool.uv.index]] -name = "pytorch-cpu" -url = "https://download.pytorch.org/whl/cpu" -explicit = true [[tool.uv.index]] name = "pytorch-cu128" @@ -58,8 +49,9 @@ url = "https://download.pytorch.org/whl/cu128" explicit = true [project.optional-dependencies] + cpu = [ - "torch==2.9.1", + ] gpu = [ "torch==2.9.1", diff --git a/runs/runcpu.sh b/runs/runcpu.sh index 853fa1f..74f4e93 100755 --- a/runs/runcpu.sh +++ b/runs/runcpu.sh @@ -10,21 +10,31 @@ # Think of this run as educational/fun demo, not something you should expect to work well. # You may also want to run this script manually and one by one, copy pasting commands into your terminal. 
-# all the setup stuff +# # all the setup stuff export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" mkdir -p $NANOCHAT_BASE_DIR command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh [ -d ".venv" ] || uv venv +# ## Fails on macbook, instead remove torch cpu from UV file and do this here uv sync --extra cpu +uv pip install torch torchvision torchaudio --upgrade --index-url https://download.pytorch.org/whl/cpu +uv pip uninstall numpy +uv pip install --force-reinstall -v "numpy==1.25.2" + source .venv/bin/activate +uv pip install torch torchvision torchaudio --upgrade --index-url https://download.pytorch.org/whl/cpu +uv pip uninstall numpy +uv pip install --force-reinstall -v "numpy==1.25.2" + if [ -z "$WANDB_RUN" ]; then WANDB_RUN=dummy fi # train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max) -python -m nanochat.dataset -n 8 -python -m scripts.tok_train --max-chars=2000000000 python -m scripts.tok_eval +# Target directory: /Users/sushrutkarnik_1/.cache/nanochat/base_data_climbmix # train a small 4 layer model # I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max. 
diff --git a/scripts/base_train.py b/scripts/base_train.py index cfbfe28..2f44196 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -242,7 +242,7 @@ def disable_fp8(model): # Compile the model orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +# model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe # ----------------------------------------------------------------------------- # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay. @@ -313,14 +313,17 @@ optimizer = model.setup_optimizer( matrix_lr=args.matrix_lr * batch_lr_scale, weight_decay=weight_decay_scaled, ) - +# optimizer.to(device) if resuming: optimizer.load_state_dict(optimizer_data) del optimizer_data # ----------------------------------------------------------------------------- # GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32) -scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None +# scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None +scaler = torch.cuda.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None +# scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) + if scaler is not None: print0("GradScaler enabled for fp16 training") diff --git a/uv.lock b/uv.lock index bbc9519..cf17127 100644 --- a/uv.lock +++ b/uv.lock @@ -1505,10 +1505,6 @@ dependencies = [ { name = "tabulate" }, { name = "tiktoken" }, { name = "tokenizers" }, - { name = "torch", version = "2.9.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' 
and extra == 'extra-8-nanochat-gpu')" }, - { name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, - { name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" }, { name = "transformers" }, { name = "uvicorn" }, { name = "wandb" }, @@ -1516,12 +1512,8 @@ dependencies = [ ] [package.optional-dependencies] -cpu = [ - { name = "torch", version = "2.9.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, -] gpu = [ - { name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" } }, + { name = "torch" }, ] [package.dev-dependencies] @@ -1545,9 +1537,7 @@ requires-dist = [ { name = "tabulate", specifier = ">=0.9.0" }, { name = "tiktoken", specifier = ">=0.11.0" }, { name = "tokenizers", specifier = ">=0.22.0" }, - { name = "torch", specifier = "==2.9.1" }, - { name = "torch", marker = "extra == 'cpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } }, - { name = "torch", marker = "extra == 'gpu'", specifier = 
"==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } }, + { name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1" }, { name = "transformers", specifier = ">=4.57.3" }, { name = "uvicorn", specifier = ">=0.36.0" }, { name = "wandb", specifier = ">=0.21.3" }, @@ -1572,13 +1562,8 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", - "python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", - "python_full_version < '3.11' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform != 'linux'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } wheels = [ @@ -1590,20 +1575,10 @@ name = "networkx" version = "3.5" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - 
"python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", - "python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", - "python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", - "python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", - "python_full_version >= '3.12' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version == '3.11.*' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version == '3.11.*' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", - "python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and 
sys_platform == 'linux'", + "python_full_version >= '3.12' and sys_platform != 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform != 'linux'", ] sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" } wheels = [ @@ -1687,7 +1662,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" }, @@ -1700,7 +1675,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = 
"2025-03-07T01:44:56.873Z" }, @@ -1732,9 +1707,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, @@ -1747,7 +1722,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 
288117129, upload-time = "2025-03-07T01:47:40.407Z" }, @@ -2989,56 +2964,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" }, ] -[[package]] -name = "torch" -version = "2.9.1" -source = { registry = "https://download.pytorch.org/whl/cpu" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'darwin'", - "python_full_version == '3.11.*' and sys_platform == 'darwin'", - "python_full_version < '3.11' and sys_platform == 'darwin'", -] -dependencies = [ - { name = "filelock", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "fsspec", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "jinja2", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version >= '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra 
== 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.12' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "sympy", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "typing-extensions", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, -] -wheels = [ - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:bf1e68cfb935ae2046374ff02a7aa73dda70351b46342846f557055b3a540bf0" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a52952a8c90a422c14627ea99b9826b7557203b46b4d0772d3ca5c7699692425" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:287242dd1f830846098b5eca847f817aa5c6015ea57ab4c1287809efea7b77eb" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8924d10d36eac8fe0652a060a03fc2ae52980841850b9a1a2ddb0f27a4f181cd" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:bcee64ae7aa65876ceeae6dcaebe75109485b213528c74939602208a20706e3f" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:defadbeb055cfcf5def58f70937145aecbd7a4bc295238ded1d0e85ae2cf0e1d" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = 
"sha256:886f84b181f766f53265ba0a1d503011e60f53fff9d569563ef94f24160e1072" }, -] - [[package]] name = "torch" version = "2.9.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version >= '3.12' and sys_platform != 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform != 'linux'", - "python_full_version < '3.11' and sys_platform != 'linux'", -] dependencies = [ - { name = "filelock", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, - { name = "fsspec", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, - { name = "jinja2", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "sympy", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 
'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, - { name = "typing-extensions", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, + { name = "filelock" }, + { name = "fsspec" }, + { name = "jinja2" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and 
sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "setuptools", marker = "python_full_version >= '3.12'" }, + { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/5f/56/9577683b23072075ed2e40d725c52c2019d71a972fab8e083763da8e707e/torch-2.9.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:1cc208435f6c379f9b8fdfd5ceb5be1e3b72a6bdf1cb46c0d2812aa73472db9e", size = 104207681, upload-time = "2025-11-12T15:19:56.48Z" }, @@ -3071,117 +3025,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/2b/f7818f6ec88758dfd21da46b6cd46af9d1b3433e53ddbb19ad1e0da17f9b/torch-2.9.1-cp314-cp314t-win_amd64.whl", hash = "sha256:c88d3299ddeb2b35dcc31753305612db485ab6f1823e37fb29451c8b2732b87e", size = 111163659, upload-time = "2025-11-12T15:23:20.009Z" }, ] -[[package]] -name = "torch" -version = "2.9.1+cpu" -source = { registry = "https://download.pytorch.org/whl/cpu" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux'", - "python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'", -] -dependencies = [ - { name = "filelock", marker = "(sys_platform 
!= 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "fsspec", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "jinja2", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version >= '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.12' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "sympy", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "typing-extensions", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, -] -wheels = 
[ - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:10866c8a48c4aa5ae3f48538dc8a055b99c57d9c6af2bf5dd715374d9d6ddca3" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:7210713b66943fdbfcc237b2e782871b649123ac5d29f548ce8c85be4223ab38" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-win_amd64.whl", hash = "sha256:d6e8441453dc27524e3f1037fbf27b90a02644b84e42944b9354b4024cb51cc1" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:0e611cfb16724e62252b67d31073bc5c490cb83e92ecdc1192762535e0e44487" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:3de2adb9b4443dc9210ef1f1b16da3647ace53553166d6360bbbd7edd6f16e4d" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_amd64.whl", hash = "sha256:69b3785d28be5a9c56ab525788ec5000349ec59132a74b7d5e954b905015b992" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_arm64.whl", hash = "sha256:15b4ae6fe371d96bffb8e1e9af62164797db20a0dc1337345781659cfd0b8bb1" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3bf9b442a51a2948e41216a76d7ab00f0694cfcaaa51b6f9bcab57b7f89843e6" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7417d8c565f219d3455654cb431c6d892a3eb40246055e14d645422de13b9ea1" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:a4e06b4f441675d26b462123c8a83e77c55f1ec8ebc081203be2db1ea8054add" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_arm64.whl", hash = 
"sha256:1abe31f14b560c1f062699e966cb08ef5b67518a1cfac2d8547a3dbcd8387b06" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:3e532e553b37ee859205a9b2d1c7977fd6922f53bbb1b9bfdd5bdc00d1a60ed4" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:39b3dff6d8fba240ae0d1bede4ca11c2531ae3b47329206512d99e17907ff74b" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:404a7ab2fffaf2ca069e662f331eb46313692b2f1630df2720094284f390ccef" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_arm64.whl", hash = "sha256:161decbff26a33f13cb5ba6d2c8f458bbf56193bcc32ecc70be6dd4c7a3ee79d" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:01b1884f724977a20c7da2f640f1c7b37f4a2c117a7f4a6c1c0424d14cb86322" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:031a597147fa81b1e6d79ccf1ad3ccc7fafa27941d6cf26ff5caaa384fb20e92" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:e586ab1363e3f86aa4cc133b7fdcf98deb1d2c13d43a7a6e5a6a18e9c5364893" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:65010ab4aacce6c9a1ddfc935f986c003ca8638ded04348fd326c3e74346237c" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:88adf5157db5da1d54b1c9fe4a6c1d20ceef00e75d854e206a87dbf69e3037dc" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-win_amd64.whl", hash = "sha256:f60e2565f261542efac07e25208fb3fc55c6fe82314a5a9cbee971edb5f27713" }, - { url = 
"https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:3ac2b8df2c55430e836dcda31940d47f1f5f94b8731057b6f20300ebea394dd9" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:5b688445f928f13563b7418b17c57e97bf955ab559cf73cd8f2b961f8572dbb3" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-win_amd64.whl", hash = "sha256:cf9c3e50b595721ca6b488bdcc326e0f1af73ed28b9b66eff504a96649bb5c96" }, -] - -[[package]] -name = "torch" -version = "2.9.1+cu128" -source = { registry = "https://download.pytorch.org/whl/cu128" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version >= '3.12' and sys_platform != 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform != 'linux'", - "python_full_version < '3.11' and sys_platform != 'linux'", -] -dependencies = [ - { name = "filelock", marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "fsspec", marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "jinja2", marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = 
"nvidia-cuda-cupti-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cuda-runtime-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cudnn-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cufft-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cufile-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-curand-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cusolver-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cusparselt-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-nccl-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 
'extra-8-nanochat-gpu')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-nvshmem-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-nvtx-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "sympy", marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "triton", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "typing-extensions", marker = "extra == 'extra-8-nanochat-gpu'" }, -] -wheels = [ - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:72f0f096475e8095a6bea3fba75bd3b46cf42c761b29588f7599314e67a32661" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c8d670aa0be6fbecd2b0e7b7d514a104dbdefcc3786ca446cf0c3415043ea40a" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-win_amd64.whl", hash = "sha256:64399adaa8ea0896d02cf844cba3c5dd77e769520a1af73572599e0eaa2cf551" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:cf4ad82430824a80a9f398e29369524ed26c152cf00c2c12002e5400b35e260d" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = 
"sha256:2a1da940f0757621d098c9755f7504d791a72a40920ec85a4fd98b20253fca4e" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:633005a3700e81b5be0df2a7d3c1d48aced23ed927653797a3bd2b144a3aeeb6" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:1176f250311fa95cc3bca8077af323e0d73ea385ba266e096af82e7e2b91f256" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7cb4018f4ce68b61fd3ef87dc1c4ca520731c7b5b200e360ad47b612d7844063" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:3a01f0b64c10a82d444d9fd06b3e8c567b1158b76b2764b8f51bfd8f535064b0" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:0b80b7555dcd0a75b7b06016991f01281a0bb078cf28fa2d1dfb949fad2fbd07" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:63381a109a569b280ed3319da89d3afe5cf9ab5c879936382a212affb5c90552" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:ad9183864acdd99fc5143d7ca9d3d2e7ddfc9a9600ff43217825d4e5e9855ccc" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2314521c74d76e513c53bb72c0ce3511ef0295ff657a432790df6c207e5d7962" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4454a4faca31af81566e3a4208f10f20b8a6d9cfe42791b0ca7ff134326468fc" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:24420e430e77136f7079354134b34e7ba9d87e539f5ac84c33b08e5c13412ebe" }, - { url = 
"https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:32c036296c557f19a1537ce981c40533650097114e1720a321a39a3b08d9df56" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:7788d3d03d939cf00f93ac0da5ab520846f66411e339cfbf519a806e8facf519" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-win_amd64.whl", hash = "sha256:7bcd40cbffac475b478d6ce812f03da84e9a4894956efb89c3b7bcca5dbd4f91" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:e88c78e5b08ae9303aa15da43b68b44287ecbec16d898d9fad6998832fe626a5" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7d8769bdf3200ca16a92f14df404c3370171ac3732996528a8973d753eac562f" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:0c784b600959ec70ee01cb23e8bc870a0e0475af30378ff5e39f4abed8b7c1cc" }, -] - [[package]] name = "tornado" version = "6.5.4" From bf19cb325cfdbd2b42ee1d9f8cd2b11e38b9afb5 Mon Sep 17 00:00:00 2001 From: Sushrut Karnik Date: Fri, 13 Mar 2026 00:07:06 +0100 Subject: [PATCH 2/3] turn tokenizer train back on --- runs/runcpu.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/runs/runcpu.sh b/runs/runcpu.sh index 74f4e93..b40b0bb 100755 --- a/runs/runcpu.sh +++ b/runs/runcpu.sh @@ -31,10 +31,9 @@ if [ -z "$WANDB_RUN" ]; then fi # train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max) -# python -m nanochat.dataset -n 8 -# python -m scripts.tok_train --max-chars=2000000000 +python -m nanochat.dataset -n 8 +python -m scripts.tok_train --max-chars=2000000000 python -m scripts.tok_eval -# Target directory: /Users/sushrutkarnik_1/.cache/nanochat/base_data_climbmix # train a small 4 layer model # 
I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max. From e9bf1a5a670a63b56b8200c70159afd04f4e306a Mon Sep 17 00:00:00 2001 From: Sushrut Karnik Date: Fri, 13 Mar 2026 01:11:24 +0100 Subject: [PATCH 3/3] chat sft optim fix --- nanochat/gpt.py | 4 ++-- nanochat/optim.py | 26 +++++++++++++------------- scripts/base_train.py | 3 +++ scripts/chat_sft.py | 4 ++-- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index fb77d77..979ff47 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -355,7 +355,7 @@ class GPT(nn.Module): 'total': total, } - def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5): + def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, device_type="cuda"): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() @@ -390,7 +390,7 @@ class GPT(nn.Module): )) Factory = DistMuonAdamW if ddp else MuonAdamW - optimizer = Factory(param_groups) + optimizer = Factory(param_groups, device_type) for group in optimizer.param_groups: group["initial_lr"] = group["lr"] return optimizer diff --git a/nanochat/optim.py b/nanochat/optim.py index 42b4949..5f67417 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -176,11 +176,11 @@ class MuonAdamW(torch.optim.Optimizer): - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' """ - def __init__(self, param_groups: list[dict]): + def __init__(self, param_groups: list[dict], device_type: str = "cpu"): super().__init__(param_groups, defaults={}) # 0-D CPU tensors to avoid torch.compile recompilation when values change # AdamW tensors - device="mps" + device=device_type self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device=device) self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device) 
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device=device) @@ -358,19 +358,19 @@ class DistMuonAdamW(torch.optim.Optimizer): - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' """ - def __init__(self, param_groups: list[dict]): + def __init__(self, param_groups: list[dict], device_type: str = "cpu"): super().__init__(param_groups, defaults={}) # 0-D CPU tensors to avoid torch.compile recompilation when values change - self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, 
device=device_type) + self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) def _reduce_adamw(self, group: dict, world_size: int) -> dict: """Launch async reduce ops for AdamW group. Returns info dict with per-param infos.""" diff --git a/scripts/base_train.py b/scripts/base_train.py index 2f44196..a9a4534 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -83,6 +83,8 @@ user_config = vars(args).copy() # for logging # Compute init and wandb logging device_type = autodetect_device_type() if args.device_type == "" else args.device_type + + ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None @@ -312,6 +314,7 @@ optimizer = model.setup_optimizer( # Muon hyperparameters matrix_lr=args.matrix_lr * batch_lr_scale, weight_decay=weight_decay_scaled, + device_type=device_type if device_type != "cuda" else "cpu" # since k keeps optim in cpu ) # optimizer.to(device) if resuming: diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..c8382f9 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -117,7 +117,7 @@ for name, fallback, source in [ print0(f"Using {name}={arg_val}") orig_model = model -model = torch.compile(model, dynamic=False) +# model = torch.compile(model, dynamic=False) depth = model.config.n_layer num_flops_per_token = model.estimate_flops() tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank @@ -131,7 +131,7 @@ token_bytes = get_token_bytes(device=device) # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) # Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero -optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, 
matrix_lr=args.matrix_lr, weight_decay=0.0) +optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0, device_type=device_type) # Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) # Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the