mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-23 16:18:05 +00:00
Merge 758af69e25 into dc54a1a307
This commit is contained in:
commit
83c571360d
|
|
@ -198,6 +198,20 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
AdamW update for each param in the group individually.
|
||||
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
|
||||
"""
|
||||
for p in group['params']:
|
||||
if p.grad is not None:
|
||||
break
|
||||
else:
|
||||
# Don't bother with the update if no params in this group have gradients
|
||||
return
|
||||
|
||||
# Pre-fill 0-D tensors with current values
|
||||
self._adamw_lr_t.fill_(group['lr'])
|
||||
self._adamw_beta1_t.fill_(group['betas'][0])
|
||||
self._adamw_beta2_t.fill_(group['betas'][1])
|
||||
self._adamw_eps_t.fill_(group['eps'])
|
||||
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
|
@ -215,11 +229,6 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
|
||||
# Fill 0-D tensors with current values
|
||||
self._adamw_step_t.fill_(state['step'])
|
||||
self._adamw_lr_t.fill_(group['lr'])
|
||||
self._adamw_beta1_t.fill_(group['betas'][0])
|
||||
self._adamw_beta2_t.fill_(group['betas'][1])
|
||||
self._adamw_eps_t.fill_(group['eps'])
|
||||
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
|
|
@ -410,6 +419,14 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
|
||||
"""Wait for reduce, compute AdamW updates, launch gathers for large params."""
|
||||
param_infos = info['param_infos']
|
||||
|
||||
# Fill 0-D tensors.
|
||||
self._adamw_lr_t.fill_(group['lr'])
|
||||
self._adamw_beta1_t.fill_(group['betas'][0])
|
||||
self._adamw_beta2_t.fill_(group['betas'][1])
|
||||
self._adamw_eps_t.fill_(group['eps'])
|
||||
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||
|
||||
for p in group['params']:
|
||||
pinfo = param_infos[p]
|
||||
pinfo['future'].wait()
|
||||
|
|
@ -432,11 +449,6 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
|
||||
# Fill 0-D tensors and run fused kernel
|
||||
self._adamw_step_t.fill_(state['step'])
|
||||
self._adamw_lr_t.fill_(group['lr'])
|
||||
self._adamw_beta1_t.fill_(group['betas'][0])
|
||||
self._adamw_beta2_t.fill_(group['betas'][1])
|
||||
self._adamw_eps_t.fill_(group['eps'])
|
||||
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||
adamw_step_fused(
|
||||
p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user