Remove the batch_num_tokens definition, which was only used for experiment logging

This commit is contained in:
svlandeg 2025-12-30 11:32:47 +01:00
parent 1ca9280328
commit 32ce342c88

View File

@@ -219,7 +219,6 @@ for step in range(num_iterations):
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
for micro_step, (train_inputs, train_targets) in enumerate(steps):
batch_num_tokens = (train_targets >= 0).sum()
with autocast_ctx:
loss = model(train_inputs, train_targets, loss_reduction='sum')
loss = loss / num_tokens # normalize loss here