Mirror of https://github.com/karpathy/nanochat.git
Add Adaptive Gradient Clipping (AGC) to pretraining
parent 5eeb2b6ef9
commit 9841848a2f
@@ -48,6 +48,7 @@ embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
 unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
 weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
 matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
+agc_clip = 0.0 # adaptive gradient clipping threshold (0.0 = disabled)
 grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
 # Evaluation
 eval_every = 250 # every how many steps to evaluate the model for val bpb
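A note on the new knob: agc_clip defaults to 0.0 (disabled) and sits next to the existing grad_clip, so both clippers can be active at once. Per parameter tensor, AGC caps the gradient norm at agc_clip times that tensor's own norm, so larger tensors tolerate proportionally larger gradients. A worked example of the semantics, with illustrative numbers only (not tuned recommendations):

    import torch

    # what agc_clip = 0.01 means for a single parameter tensor
    p = torch.full((10, 10), 0.1)   # parameter, ||p|| = sqrt(100 * 0.01) = 1.0
    g = torch.full((10, 10), 0.5)   # gradient,  ||g|| = sqrt(100 * 0.25) = 5.0
    max_norm = p.norm() * 0.01      # AGC cap: 0.01 * 1.0 = 0.01
    g = g * (max_norm / g.norm())   # rescale so that ||g|| becomes 0.01
    print(g.norm())                 # tensor(0.0100)

To turn it on, pass the override the same way as the other config values in this file, e.g. --agc_clip=0.01 (assuming nanochat's usual --key=value config overrides; the threshold value here is illustrative).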
@@ -172,6 +173,19 @@ def get_muon_momentum(it):
     momentum = (1 - frac) * 0.85 + frac * 0.95
     return momentum
 
+# Adaptive gradient clipping (AGC)
+@torch.compile
+@torch.no_grad()
+def clip_grad_agc(model, clip_threshold, eps=1e-3):
+    for p in model.parameters():
+        if p.grad is None:
+            continue
+        param_norm = p.norm().clamp(min=eps)
+        grad_norm = p.grad.norm().clamp(min=eps)
+        max_norm = param_norm * clip_threshold
+        if grad_norm > max_norm:
+            p.grad.mul_(max_norm / grad_norm)
+
 # -----------------------------------------------------------------------------
 # Training loop
 min_val_bpb = float("inf")
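For reference, the function above clips on whole-tensor norms. The original AGC formulation (Brock et al., 2021, "High-Performance Large-Scale Image Recognition Without Normalization", arXiv:2102.06171) instead clips per output unit, i.e. row-wise on each weight matrix, and exempts 1-D parameters such as gains and biases. A hedged sketch of that variant; clip_grad_agc_unitwise is a hypothetical name, not part of this commit:

    import torch

    @torch.no_grad()
    def clip_grad_agc_unitwise(model, clip_threshold, eps=1e-3):
        for p in model.parameters():
            if p.grad is None or p.ndim < 2:
                continue  # the paper exempts gains/biases (1-D params) from AGC
            w = p.view(p.shape[0], -1)       # one row per output unit
            g = p.grad.view(p.shape[0], -1)  # view shares storage with p.grad
            param_norm = w.norm(dim=1, keepdim=True).clamp(min=eps)
            grad_norm = g.norm(dim=1, keepdim=True).clamp(min=eps)
            # shrink each row only where its grad norm exceeds the cap
            scale = (clip_threshold * param_norm / grad_norm).clamp(max=1.0)
            g.mul_(scale)  # in-place on the view mutates p.grad directly

The per-tensor version in the commit is cheaper and simpler; the per-unit version is stricter, since one large row can no longer hide behind an otherwise small tensor.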
@@ -272,6 +286,9 @@ for step in range(num_iterations + 1):
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
         x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
+    # adaptive gradient clipping (AGC)
+    if agc_clip > 0.0:
+        clip_grad_agc(orig_model, agc_clip)
     # gradient clipping (TODO possibly experiment with)
     if grad_clip > 0.0:
         torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
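A minimal smoke test for the wiring above (illustrative, not part of the commit; paste it alongside the clip_grad_agc definition from the diff):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8, bias=False)
    model.weight.grad = 100.0 * torch.randn_like(model.weight)  # oversized grad
    clip_grad_agc(model, clip_threshold=0.01)  # first call triggers torch.compile
    ratio = model.weight.grad.norm() / model.weight.norm()
    assert ratio <= 0.01 + 1e-5, ratio  # grad norm capped at 0.01 * param norm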