From 0bf88031f33ec0d39dd0313fb0a4235e724bdc49 Mon Sep 17 00:00:00 2001 From: Rohith Reddy Date: Sat, 7 Mar 2026 16:16:33 +0530 Subject: [PATCH] =?UTF-8?q?feat:=20support=20non-Triton=20platforms=20(aar?= =?UTF-8?q?ch64/Jetson)=20=E2=80=94=20conditional=20torch.compile?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit torch.compile requires Triton which is unavailable on aarch64 (NVIDIA Jetson, ARM servers). This makes compilation conditional: - nanochat/optim.py: Replace hardcoded @torch.compile decorators with _compile_decorator that becomes a no-op when Triton is not installed - scripts/base_train.py: Wrap torch.compile(model) in try/except - nanochat/common.py: Add Orin to peak FLOPS table (5.3 TFLOPS BF16) Tested end-to-end on NVIDIA Jetson Orin NX 8GB (Compute 8.7, CUDA 12.6, PyTorch 2.8.0) — pretrain and SFT both run successfully in eager mode. No changes to behavior on platforms where Triton is available. --- nanochat/common.py | 2 ++ nanochat/optim.py | 12 ++++++++++-- scripts/base_train.py | 12 +++++++++++- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index bd14fd2..72a756a 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -264,6 +264,8 @@ def get_peak_flops(device_name: str) -> float: (["5090"], 209.5e12), (["4090"], 165.2e12), (["3090"], 71e12), + # NVIDIA Jetson (unified memory, iGPU) + (["orin"], 5.3e12), # Jetson Orin NX 8GB: ~5.3 TFLOPS BF16 (8 SMs @ 1.0GHz) ) for patterns, flops in _PEAK_FLOPS_TABLE: if all(p in name for p in patterns): diff --git a/nanochat/optim.py b/nanochat/optim.py index 42d862b..7008aa6 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -11,13 +11,21 @@ import torch import torch.distributed as dist from torch import Tensor +# Check if torch.compile works (requires Triton on CUDA) +_can_compile = True +try: + import triton # noqa: F401 +except ImportError: + _can_compile = False +_compile_decorator = torch.compile(dynamic=False, fullgraph=True) if _can_compile else (lambda fn: fn) + # ----------------------------------------------------------------------------- """ Good old AdamW optimizer, fused kernel. https://arxiv.org/abs/1711.05101 """ -@torch.compile(dynamic=False, fullgraph=True) +@_compile_decorator def adamw_step_fused( p: Tensor, # (32768, 768) - parameter tensor grad: Tensor, # (32768, 768) - gradient, same shape as p @@ -87,7 +95,7 @@ polar_express_coeffs = [ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), ] -@torch.compile(dynamic=False, fullgraph=True) +@_compile_decorator def muon_step_fused( stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients stacked_params: Tensor, # (12, 768, 3072) - stacked parameters diff --git a/scripts/base_train.py b/scripts/base_train.py index 4bf7959..454ac41 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -244,7 +244,17 @@ def disable_fp8(model): # Compile the model orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +# torch.compile requires Triton, which is not available on all platforms (e.g. Jetson aarch64) +_can_compile = True +try: + import triton # noqa: F401 +except ImportError: + _can_compile = False +if _can_compile and device_type == "cuda": + model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe + print0("✓ torch.compile enabled") +else: + print0("⚠ torch.compile disabled (Triton not available or non-CUDA device), running in eager mode") # ----------------------------------------------------------------------------- # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.