diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 01d9f7d..c127c94 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -249,7 +249,7 @@ while True: last_step = bool(last_step_tensor.item()) # once in a while: evaluate the val bpb (all ranks participate) - if args.eval_every > 0 and (last_step or step % args.eval_every == 0): + if last_step or (args.eval_every > 0 and step % args.eval_every == 0): model.eval() val_loader = build_val_loader() eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)