printing steps count

This commit is contained in:
gpu-poor 2026-02-26 10:07:09 +00:00
parent c7ba252142
commit daf7ec9156

View File

@ -28,7 +28,8 @@ def evaluate_bpb(model, batches, steps, token_bytes):
total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
batch_iter = iter(batches)
for _ in range(steps):
for step in range(steps):
print(f"\reval {step+1}/{steps}", end="", flush=True)
x, y = next(batch_iter)
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
@ -51,6 +52,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
num_bytes2d = token_bytes[y]
total_nats += (loss2d * (num_bytes2d > 0)).sum()
total_bytes += num_bytes2d.sum()
print() # newline after progress
# sum reduce across all ranks
world_size = dist.get_world_size() if dist.is_initialized() else 1
if world_size > 1: