diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 6c2b82f..aa57724 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -188,7 +188,7 @@ while True: last_step = bool(last_step_tensor.item()) # once in a while: evaluate the val bpb (all ranks participate) - if eval_every > 0 and (last_step or step % eval_every == 0): + if last_step or (eval_every > 0 and step % eval_every == 0): model.eval() val_loader = build_val_loader() eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)