Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-20 18:34:14 +00:00)

delete grad_clip. appears to not be necessary at all. not only was it buggy because the clipping happened per gpu before grad synchronization, but it costs ~2% MFU, and it also doesn't even help. I tried deleting it a while ago and back then it did help. So I'm guessing that some hyperparameter tuning obviated the reason for it since then
This commit is contained in:
parent e8c30c3b19
commit 061f83c152

dev/LOG.md (new file, +23 lines)

@@ -0,0 +1,23 @@
# Experiment Log

A running summary documenting some experiments and findings. Started ~Jan 7 2026.

---

## 2026-01-08: exp_grad_clip - Gradient Clipping

**Hypothesis:** Gradient clipping may be unnecessary overhead. Tested L2 norm clipping at various thresholds (0.25, 0.5, 1.0, 2.0) and elementwise clipping.
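
For reference, the two forms being compared map naturally onto PyTorch's built-in utilities. A minimal sketch (an editor's illustration, not part of the original log or commit; the toy model is a placeholder, and whether the experiment used exactly these calls isn't stated):

```python
import torch
from torch import nn

# Toy stand-in model, just to have parameters with populated gradients
model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

# L2 norm clipping: rescale the whole gradient vector if its total norm exceeds the threshold
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Elementwise clipping: clamp each gradient entry to [-clip_value, clip_value] independently
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
```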

**Results:**
- No benefit at any scale tested (d12, d20)
- All variants within noise (~0.9827 val_bpb)
- Grad norm never exceeds 1.0 naturally, so clipping is always inactive
- Clipping adds ~2% time overhead from the all-reduce

**Bug Found:** The original implementation clipped local gradients before sync. Since this codebase doesn't use DDP (gradient sync is in the optimizers), each rank was clipping based on its own local norm. Fixed on the branch with a proper distributed all-reduce.
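
A minimal sketch of one way to make clipping rank-consistent when gradient sync lives in the optimizers rather than in DDP (an editor's illustration under assumptions, not the code from the branch): every rank agrees on a single norm statistic via an all-reduce and applies the same scaling factor, so the later gradient sync sees consistently scaled gradients.

```python
import torch
import torch.distributed as dist

@torch.no_grad()
def clip_grad_norm_global_(params, max_norm: float) -> float:
    """Hypothetical helper: clip against a norm statistic shared by all ranks.

    Each rank sums the squares of its local gradients, the partial sums are
    all-reduced so every rank sees the same total, and every rank applies the
    same clip coefficient. (The statistic used here is the norm of the per-rank
    gradients stacked together; which norm definition to use is a design choice.)
    """
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        return 0.0
    sq_sum = torch.zeros((), device=grads[0].device)
    for g in grads:
        sq_sum += g.detach().float().pow(2).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM)  # same value on every rank after this
    total_norm = sq_sum.sqrt()
    clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for g in grads:
        g.mul_(clip_coef)
    return total_norm.item()  # .item() is a CPU-GPU sync point
```

Even when the threshold is never hit, the norm computation, the extra all-reduce, and the `.item()` sync still cost wall-clock time, which is where overhead like the ~2% above comes from.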

**Observation:** modded-nanogpt does not appear to clip either right now.

**Recommendation:** Disable by default (`--grad_clip=0.0`). The code naturally produces well-behaved gradients.

---

@@ -55,7 +55,6 @@ parser.add_argument("--weight_decay", type=float, default=0.0, help="weight deca
 parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
 parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
 parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
-parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping value (0.0 = disabled)")
 parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
 parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
 parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")

@@ -346,11 +345,6 @@ while True:
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
         x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
-    # gradient clipping
-    grad_clip_enabled = args.grad_clip > 0.0
-    if grad_clip_enabled:
-        grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), args.grad_clip)
-        grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
     # step the optimizers
     lrm = get_lr_multiplier(step)
     for opt in optimizers:

@@ -378,7 +372,6 @@ while True:
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
-    print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
     # Calculate ETA based on average time per step (excluding first 10 steps)
     steps_done = step - 10
     if steps_done > 0:

@@ -388,7 +381,7 @@ while True:
         eta_str = f" | eta: {eta_seconds/60:.1f}m"
     else:
         eta_str = ""
-    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
+    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
     if step % 100 == 0:
         log_data = {
             "step": step,

@@ -400,8 +393,6 @@ while True:
             "train/tok_per_sec": tok_per_sec,
             "train/mfu": mfu,
         }
-        if grad_clip_enabled:
-            log_data["train/grad_norm"] = grad_norm
         wandb_run.log(log_data)

     # state update