This commit is contained in:
rohithreddy1095 2026-03-27 22:10:04 +00:00 committed by GitHub
commit fc27203942
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 3 deletions

View File

@ -264,6 +264,8 @@ def get_peak_flops(device_name: str) -> float:
(["5090"], 209.5e12),
(["4090"], 165.2e12),
(["3090"], 71e12),
# NVIDIA Jetson (unified memory, iGPU)
(["orin"], 5.3e12), # Jetson Orin NX 8GB: ~5.3 TFLOPS BF16 (8 SMs @ 1.0GHz)
)
for patterns, flops in _PEAK_FLOPS_TABLE:
if all(p in name for p in patterns):

View File

@ -12,13 +12,21 @@ import torch.distributed as dist
from torch import Tensor
from nanochat.common import COMPUTE_DTYPE
# Check if torch.compile works (requires Triton on CUDA)
_can_compile = True
try:
import triton # noqa: F401
except ImportError:
_can_compile = False
_compile_decorator = torch.compile(dynamic=False, fullgraph=True) if _can_compile else (lambda fn: fn)
# -----------------------------------------------------------------------------
"""
Good old AdamW optimizer, fused kernel.
https://arxiv.org/abs/1711.05101
"""
@torch.compile(dynamic=False, fullgraph=True)
@_compile_decorator
def adamw_step_fused(
p: Tensor, # (32768, 768) - parameter tensor
grad: Tensor, # (32768, 768) - gradient, same shape as p
@ -88,7 +96,7 @@ polar_express_coeffs = [
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile(dynamic=False, fullgraph=True)
@_compile_decorator
def muon_step_fused(
stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
stacked_params: Tensor, # (12, 768, 3072) - stacked parameters

View File

@ -243,7 +243,17 @@ def disable_fp8(model):
# Compile the model
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
# torch.compile requires Triton, which is not available on all platforms (e.g. Jetson aarch64)
try:
    import triton  # noqa: F401
    _can_compile = True
except ImportError:
    _can_compile = False

# Only compile on CUDA with Triton present; otherwise fall back to eager mode.
if device_type == "cuda" and _can_compile:
    # the inputs to model will never change shape so dynamic=False is safe
    model = torch.compile(model, dynamic=False)
    print0("✓ torch.compile enabled")
else:
    print0("⚠ torch.compile disabled (Triton not available or non-CUDA device), running in eager mode")
# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.