Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-04 18:52:36 +00:00)

Add Adaptive Gradient Clipping (AGC) to pretraining
parent 5eeb2b6ef9
commit 9841848a2f

@@ -48,6 +48,7 @@ embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
 unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
 weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
 matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
+agc_clip = 0.0 # adaptive gradient clipping threshold (0.0 = disabled)
 grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
 # Evaluation
 eval_every = 250 # every how many steps to evaluate the model for val bpb
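
The two thresholds differ in kind: agc_clip is relative (each parameter tensor's gradient norm is capped at agc_clip times that tensor's own norm), while grad_clip is an absolute cap on the global gradient norm across all parameters. A rough numerical illustration with made-up values, not part of the commit:

# Hypothetical numbers, for illustration only.
param_norm = 2.0    # ||p|| of one weight matrix
agc_clip = 0.01     # per-tensor, relative threshold (disabled by default above)
grad_clip = 1.0     # model-wide, absolute threshold

per_tensor_cap = param_norm * agc_clip  # 0.02, scales with the parameter
global_cap = grad_clip                  # 1.0, fixed regardless of parameter size
print(per_tensor_cap, global_cap)
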
@@ -172,6 +173,19 @@ def get_muon_momentum(it):
     momentum = (1 - frac) * 0.85 + frac * 0.95
     return momentum

+# Adaptive gradient clipping (AGC)
+@torch.compile
+@torch.no_grad()
+def clip_grad_agc(model, clip_threshold, eps=1e-3):
+    for p in model.parameters():
+        if p.grad is None:
+            continue
+        param_norm = p.norm().clamp(min=eps)
+        grad_norm = p.grad.norm().clamp(min=eps)
+        max_norm = param_norm * clip_threshold
+        if grad_norm > max_norm:
+            p.grad.mul_(max_norm / grad_norm)
+
 # -----------------------------------------------------------------------------
 # Training loop
 min_val_bpb = float("inf")
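
A minimal standalone sketch of the rule clip_grad_agc implements: a parameter's gradient is rescaled whenever its norm exceeds clip_threshold times that parameter's own norm. The function is re-declared here without @torch.compile so the snippet runs on its own; the toy linear layer, input, and threshold are made up for illustration. (For reference, the AGC of Brock et al., 2021 uses unit-wise norms; this version clips per whole tensor.)

import torch
import torch.nn as nn

@torch.no_grad()
def clip_grad_agc(model, clip_threshold, eps=1e-3):
    # same rule as the function added above, minus @torch.compile
    for p in model.parameters():
        if p.grad is None:
            continue
        param_norm = p.norm().clamp(min=eps)
        grad_norm = p.grad.norm().clamp(min=eps)
        max_norm = param_norm * clip_threshold
        if grad_norm > max_norm:
            p.grad.mul_(max_norm / grad_norm)

model = nn.Linear(4, 4, bias=False)                     # toy stand-in for the real model
loss = (model(torch.randn(8, 4)) ** 2).mean() * 100.0   # inflated to force a large gradient
loss.backward()

w = model.weight
print("grad/param ratio before:", (w.grad.norm() / w.norm()).item())
clip_grad_agc(model, clip_threshold=0.01)
print("grad/param ratio after: ", (w.grad.norm() / w.norm()).item())  # capped at ~0.01
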
@@ -272,6 +286,9 @@ for step in range(num_iterations + 1):
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
         x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
+    # adaptive gradient clipping (AGC)
+    if agc_clip > 0.0:
+        clip_grad_agc(orig_model, agc_clip)
     # gradient clipping (TODO possibly expertiment with)
     if grad_clip > 0.0:
         torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
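
In the modified loop both clips can be active: the per-tensor AGC runs first, then the pre-existing global norm clip. A small continuation of the sketch above (reusing clip_grad_agc and the imports as defined there; the model and values are again made up) shows that ordering:

agc_clip, grad_clip = 0.01, 1.0           # illustrative values, not the defaults
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 1))
loss = model(torch.randn(32, 16)).pow(2).mean() * 50.0
loss.backward()

# Same order as in the diff: relative per-tensor clip, then absolute global clip.
if agc_clip > 0.0:
    clip_grad_agc(model, agc_clip)
if grad_clip > 0.0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

total = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
print("global grad norm after both clips:", total.item())  # <= grad_clip
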