diff --git a/nanochat/common.py b/nanochat/common.py index bd14fd2..72a756a 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -264,6 +264,8 @@ def get_peak_flops(device_name: str) -> float: (["5090"], 209.5e12), (["4090"], 165.2e12), (["3090"], 71e12), + # NVIDIA Jetson (unified memory, iGPU) + (["orin"], 5.3e12), # Jetson Orin NX 8GB: ~5.3 TFLOPS BF16 (8 SMs @ 1.0GHz) ) for patterns, flops in _PEAK_FLOPS_TABLE: if all(p in name for p in patterns): diff --git a/nanochat/optim.py b/nanochat/optim.py index 42d862b..7008aa6 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -11,13 +11,21 @@ import torch import torch.distributed as dist from torch import Tensor +# Check if torch.compile works (requires Triton on CUDA) +_can_compile = True +try: + import triton # noqa: F401 +except ImportError: + _can_compile = False +_compile_decorator = torch.compile(dynamic=False, fullgraph=True) if _can_compile else (lambda fn: fn) + # ----------------------------------------------------------------------------- """ Good old AdamW optimizer, fused kernel. https://arxiv.org/abs/1711.05101 """ -@torch.compile(dynamic=False, fullgraph=True) +@_compile_decorator def adamw_step_fused( p: Tensor, # (32768, 768) - parameter tensor grad: Tensor, # (32768, 768) - gradient, same shape as p @@ -87,7 +95,7 @@ polar_express_coeffs = [ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), ] -@torch.compile(dynamic=False, fullgraph=True) +@_compile_decorator def muon_step_fused( stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients stacked_params: Tensor, # (12, 768, 3072) - stacked parameters diff --git a/scripts/base_train.py b/scripts/base_train.py index 4bf7959..454ac41 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -244,7 +244,17 @@ def disable_fp8(model): # Compile the model orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +# torch.compile requires Triton, which is not available on all platforms (e.g. Jetson aarch64) +_can_compile = True +try: + import triton # noqa: F401 +except ImportError: + _can_compile = False +if _can_compile and device_type == "cuda": + model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe + print0("✓ torch.compile enabled") +else: + print0("⚠ torch.compile disabled (Triton not available or non-CUDA device), running in eager mode") # ----------------------------------------------------------------------------- # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.