feat: support non-Triton platforms (aarch64/Jetson) — conditional torch.compile

torch.compile requires Triton, which is unavailable on aarch64 (NVIDIA
Jetson, ARM servers). This commit makes compilation conditional:

- nanochat/optim.py: Replace hardcoded @torch.compile decorators with
  _compile_decorator that becomes a no-op when Triton is not installed
- scripts/base_train.py: Call torch.compile(model) only when Triton is
  importable and the device is CUDA; otherwise fall back to eager mode
- nanochat/common.py: Add Orin to peak FLOPS table (5.3 TFLOPS BF16)

Tested end-to-end on NVIDIA Jetson Orin NX 8GB (Compute 8.7, CUDA 12.6,
PyTorch 2.8.0) — pretrain and SFT both run successfully in eager mode.
No changes to behavior on platforms where Triton is available.
Rohith Reddy 2026-03-07 16:16:33 +05:30
parent 1076f97059
commit 0bf88031f3
3 changed files with 23 additions and 3 deletions

nanochat/common.py

@@ -264,6 +264,8 @@ def get_peak_flops(device_name: str) -> float:
         (["5090"], 209.5e12),
         (["4090"], 165.2e12),
         (["3090"], 71e12),
+        # NVIDIA Jetson (unified memory, iGPU)
+        (["orin"], 5.3e12), # Jetson Orin NX 8GB: ~5.3 TFLOPS BF16 (8 SMs @ 1.0GHz)
     )
     for patterns, flops in _PEAK_FLOPS_TABLE:
         if all(p in name for p in patterns):
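
For context, here is a minimal, self-contained sketch of how this substring-match table resolves a Jetson device name. The function body, the name lowering, and the NaN fallback are assumptions for illustration, not the upstream get_peak_flops implementation:

# Sketch only: simplified lookup mirroring the (assumed) table structure above.
_PEAK_FLOPS_TABLE = (
    (["5090"], 209.5e12),
    (["4090"], 165.2e12),
    (["3090"], 71e12),
    (["orin"], 5.3e12),  # new entry from this commit
)

def get_peak_flops(device_name: str) -> float:
    name = device_name.lower()
    for patterns, flops in _PEAK_FLOPS_TABLE:
        if all(p in name for p in patterns):
            return flops
    return float("nan")  # unknown GPU: assumed fallback, upstream may differ

print(get_peak_flops("Orin"))  # 5.3e12 (Jetson Orin boards typically report "Orin" as the CUDA device name)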

nanochat/optim.py

@@ -11,13 +11,21 @@ import torch
 import torch.distributed as dist
 from torch import Tensor
 
+# Check if torch.compile works (requires Triton on CUDA)
+_can_compile = True
+try:
+    import triton # noqa: F401
+except ImportError:
+    _can_compile = False
+_compile_decorator = torch.compile(dynamic=False, fullgraph=True) if _can_compile else (lambda fn: fn)
+
 # -----------------------------------------------------------------------------
 """
 Good old AdamW optimizer, fused kernel.
 https://arxiv.org/abs/1711.05101
 """
 
-@torch.compile(dynamic=False, fullgraph=True)
+@_compile_decorator
 def adamw_step_fused(
     p: Tensor, # (32768, 768) - parameter tensor
     grad: Tensor, # (32768, 768) - gradient, same shape as p

@@ -87,7 +95,7 @@ polar_express_coeffs = [
     (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
 ]
 
-@torch.compile(dynamic=False, fullgraph=True)
+@_compile_decorator
 def muon_step_fused(
     stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
     stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
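
To show the pattern in isolation, here is a self-contained sketch of the same Triton-gated decorator applied to a dummy step function; dummy_step and its SGD-style update are placeholders for illustration, not nanochat's fused optimizers:

import torch

_can_compile = True
try:
    import triton # noqa: F401
except ImportError:
    _can_compile = False
# Same fallback as above: compile when Triton exists, plain no-op decorator otherwise
_compile_decorator = torch.compile(dynamic=False, fullgraph=True) if _can_compile else (lambda fn: fn)

@_compile_decorator
def dummy_step(p: torch.Tensor, grad: torch.Tensor, lr: float) -> None:
    p.add_(grad, alpha=-lr)  # placeholder update, stands in for adamw_step_fused / muon_step_fused

p = torch.zeros(4)
dummy_step(p, torch.ones(4), lr=0.1)
print(p)  # tensor([-0.1000, -0.1000, -0.1000, -0.1000]) in both compiled and eager modes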

scripts/base_train.py

@@ -244,7 +244,17 @@ def disable_fp8(model):
 
 # Compile the model
 orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
-model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
+# torch.compile requires Triton, which is not available on all platforms (e.g. Jetson aarch64)
+_can_compile = True
+try:
+    import triton # noqa: F401
+except ImportError:
+    _can_compile = False
+if _can_compile and device_type == "cuda":
+    model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
+    print0("✓ torch.compile enabled")
+else:
+    print0("⚠ torch.compile disabled (Triton not available or non-CUDA device), running in eager mode")
 
 # -----------------------------------------------------------------------------
 # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
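
As a quick way to confirm which path the guard took on a given machine, one can check whether torch.compile actually wrapped the module. The _orig_mod attribute used below is a PyTorch internal detail, and this snippet is illustration only, not part of the commit:

import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in for the nanochat model
try:
    import triton # noqa: F401
    model = torch.compile(model, dynamic=False)
except ImportError:
    pass  # eager fallback, as on Jetson / aarch64

# torch.compile wraps nn.Modules in an OptimizedModule that keeps the original
# module under ._orig_mod; its absence means the model is running eagerly.
print("compiled" if hasattr(model, "_orig_mod") else "eager")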