diff --git a/scripts/base_train.py b/scripts/base_train.py
index 4ca8cdc..fc37ac0 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -48,6 +48,7 @@ embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
 unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
 weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
 matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
+agc_clip = 0.0 # adaptive gradient clipping threshold (0.0 = disabled)
 grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
 # Evaluation
 eval_every = 250 # every how many steps to evaluate the model for val bpb
@@ -172,6 +173,19 @@ def get_muon_momentum(it):
     momentum = (1 - frac) * 0.85 + frac * 0.95
     return momentum

+# Adaptive gradient clipping (AGC)
+@torch.compile
+@torch.no_grad()
+def clip_grad_agc(model, clip_threshold, eps=1e-3):
+    for p in model.parameters():
+        if p.grad is None:
+            continue
+        param_norm = p.norm().clamp(min=eps)
+        grad_norm = p.grad.norm().clamp(min=eps)
+        max_norm = param_norm * clip_threshold
+        if grad_norm > max_norm:
+            p.grad.mul_(max_norm / grad_norm)
+
 # -----------------------------------------------------------------------------
 # Training loop
 min_val_bpb = float("inf")
@@ -272,6 +286,9 @@ for step in range(num_iterations + 1):
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
         x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
+    # adaptive gradient clipping (AGC)
+    if agc_clip > 0.0:
+        clip_grad_agc(orig_model, agc_clip)
     # gradient clipping (TODO possibly expertiment with)
     if grad_clip > 0.0:
         torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
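
Note: the `clip_grad_agc` added above rescales each parameter tensor's gradient so that `||grad|| <= agc_clip * ||param||`, i.e. a per-tensor simplification of adaptive gradient clipping. For comparison, the AGC paper (Brock et al., 2021) applies the rule per output unit (row-wise). Below is a minimal, hedged sketch of that unit-wise variant; the function name, the `params` iterable, and the fallback for 1-D tensors are illustrative and not part of base_train.py:

```python
import torch

@torch.no_grad()
def clip_grad_agc_unitwise(params, clip_threshold, eps=1e-3):
    """Row-wise AGC sketch: clip each output unit's grad norm to
    clip_threshold * that unit's weight norm."""
    for p in params:
        if p.grad is None:
            continue
        if p.ndim <= 1:
            # biases / gains: fall back to whole-tensor norms
            param_norm = p.norm().clamp(min=eps)
            grad_norm = p.grad.norm()
            max_norm = param_norm * clip_threshold
            if grad_norm > max_norm:
                p.grad.mul_(max_norm / grad_norm)
        else:
            # treat the first dim as output units, flatten the rest
            flat_p = p.flatten(1)
            flat_g = p.grad.flatten(1)
            param_norm = flat_p.norm(dim=1, keepdim=True).clamp(min=eps)
            grad_norm = flat_g.norm(dim=1, keepdim=True).clamp(min=1e-6)
            scale = (param_norm * clip_threshold / grad_norm).clamp(max=1.0)
            # broadcast the per-unit scale back over the trailing dims
            p.grad.mul_(scale.view(-1, *([1] * (p.ndim - 1))))
```

With `agc_clip = 0.0` as the default, the new code path is disabled and training behaviour is unchanged; setting it to a small positive value (e.g. 0.01–0.1 is a typical range in the AGC paper, not a value prescribed by this diff) enables the clipping before the existing global-norm clip.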