diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index 0100ec3..6fcbea3 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -59,5 +59,7 @@ def evaluate_bpb(model, batches, steps, token_bytes): # move both to cpu, calculate bpb and return total_nats = total_nats.item() total_bytes = total_bytes.item() + if total_bytes == 0: + return float('inf') bpb = total_nats / (math.log(2) * total_bytes) return bpb