mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
Merge 0bf88031f3 into a445144d39
This commit is contained in:
commit
fc27203942
|
|
@ -264,6 +264,8 @@ def get_peak_flops(device_name: str) -> float:
|
|||
(["5090"], 209.5e12),
|
||||
(["4090"], 165.2e12),
|
||||
(["3090"], 71e12),
|
||||
# NVIDIA Jetson (unified memory, iGPU)
|
||||
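(["orin"], 5.3e12), # Jetson Orin NX 8GB: ~5.3 TFLOPS BF16 (8 SMs @ 1.0GHz). NOTE(review): matching is by substring, so any device name containing "orin" (e.g. other Orin variants) gets this same figure - confirm that is intended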
|
||||
)
|
||||
for patterns, flops in _PEAK_FLOPS_TABLE:
|
||||
if all(p in name for p in patterns):
|
||||
|
|
|
|||
|
|
@ -12,13 +12,21 @@ import torch.distributed as dist
|
|||
from torch import Tensor
|
||||
from nanochat.common import COMPUTE_DTYPE
|
||||
|
||||
# Check if torch.compile works (requires Triton on CUDA)
|
||||
_can_compile = True
|
||||
try:
|
||||
import triton # noqa: F401
|
||||
except ImportError:
|
||||
_can_compile = False
|
||||
_compile_decorator = torch.compile(dynamic=False, fullgraph=True) if _can_compile else (lambda fn: fn)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
"""
|
||||
Good old AdamW optimizer, fused kernel.
|
||||
https://arxiv.org/abs/1711.05101
|
||||
"""
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
@_compile_decorator
|
||||
def adamw_step_fused(
|
||||
p: Tensor, # (32768, 768) - parameter tensor
|
||||
grad: Tensor, # (32768, 768) - gradient, same shape as p
|
||||
|
|
@ -88,7 +96,7 @@ polar_express_coeffs = [
|
|||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
@_compile_decorator
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
|
||||
stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
|
||||
|
|
|
|||
|
|
@ -243,7 +243,17 @@ def disable_fp8(model):
|
|||
# Compile the model
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the input shapes may change)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
# torch.compile requires Triton, which is not available on all platforms (e.g. Jetson aarch64)
|
||||
_can_compile = True
|
||||
try:
|
||||
import triton # noqa: F401
|
||||
except ImportError:
|
||||
_can_compile = False
|
||||
if _can_compile and device_type == "cuda":
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
print0("✓ torch.compile enabled")
|
||||
else:
|
||||
print0("⚠ torch.compile disabled (Triton not available or non-CUDA device), running in eager mode")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user