mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-07 01:40:30 +00:00
printing steps count
This commit is contained in:
parent
c7ba252142
commit
daf7ec9156
|
|
@ -28,7 +28,8 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
|
||||
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
|
||||
batch_iter = iter(batches)
|
||||
for _ in range(steps):
|
||||
for step in range(steps):
|
||||
print(f"\reval {step+1}/{steps}", end="", flush=True)
|
||||
x, y = next(batch_iter)
|
||||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
|
|
@ -51,6 +52,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
num_bytes2d = token_bytes[y]
|
||||
total_nats += (loss2d * (num_bytes2d > 0)).sum()
|
||||
total_bytes += num_bytes2d.sum()
|
||||
print() # newline after progress
|
||||
# sum reduce across all ranks
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
if world_size > 1:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user