From daf7ec9156813c651f2095b835d7982ddcfda8aa Mon Sep 17 00:00:00 2001 From: gpu-poor Date: Thu, 26 Feb 2026 10:07:09 +0000 Subject: [PATCH] printing steps count --- nanochat/loss_eval.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index 5a556e6..c983601 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -28,7 +28,8 @@ def evaluate_bpb(model, batches, steps, token_bytes): total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device()) total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device()) batch_iter = iter(batches) - for _ in range(steps): + for step in range(steps): + print(f"\reval {step+1}/{steps}", end="", flush=True) x, y = next(batch_iter) loss2d = model(x, y, loss_reduction='none') # (B, T) loss2d = loss2d.view(-1) # flatten @@ -51,6 +52,7 @@ def evaluate_bpb(model, batches, steps, token_bytes): num_bytes2d = token_bytes[y] total_nats += (loss2d * (num_bytes2d > 0)).sum() total_bytes += num_bytes2d.sum() + print() # newline after progress # sum reduce across all ranks world_size = dist.get_world_size() if dist.is_initialized() else 1 if world_size > 1: