fix tok/sec calculation bug when grad accum steps > 1

Fix tok/sec metrics for base_train and mid_train when gradient accumulation is not 1
commit dfc88334b6
Author: Andrej
Date: 2025-10-30 08:36:32 -07:00 (committed via GitHub)
2 changed files with 2 additions and 2 deletions
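The bug: dt times a full optimizer step, which spans every gradient-accumulation micro-step, while world_tokens_per_fwdbwd only counts the tokens of a single forward/backward pass across all ranks, so the reported tok/sec was low by a factor of grad_accum_steps whenever accumulation was enabled. A minimal sketch of the arithmetic, assuming total_batch_size = world_tokens_per_fwdbwd * grad_accum_steps as in the training scripts (the numbers and the device_batch_size/sequence_len names below are hypothetical):

# Sketch of the tok/sec fix (hypothetical numbers; core variable names from the diff).
ddp_world_size = 8                                  # number of GPUs
device_batch_size, sequence_len = 32, 2048          # hypothetical per-device batch shape
grad_accum_steps = 4                                # > 1 is where the old metric went wrong

# tokens pushed through one forward/backward pass, summed across all ranks
world_tokens_per_fwdbwd = device_batch_size * sequence_len * ddp_world_size
# tokens processed per optimizer step, which is what dt actually times
total_batch_size = world_tokens_per_fwdbwd * grad_accum_steps

dt = 2.0  # seconds per optimizer step (hypothetical)
old_tok_per_sec = int(world_tokens_per_fwdbwd / dt)  # 262144, low by grad_accum_steps x
new_tok_per_sec = int(total_batch_size / dt)          # 1048576, consistent with the flops math

With grad_accum_steps == 1 the two expressions coincide, which is why the metric only looked wrong once gradient accumulation was turned on.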

base_train:

@@ -294,7 +294,7 @@ for step in range(num_iterations + 1):
     smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
     debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
     pct_done = 100 * step / num_iterations
-    tok_per_sec = int(world_tokens_per_fwdbwd / dt)
+    tok_per_sec = int(total_batch_size / dt)
     flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %

mid_train:

@@ -268,7 +268,7 @@ while True:
     smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
     debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
     pct_done = 100 * progress
-    tok_per_sec = int(world_tokens_per_fwdbwd / dt)
+    tok_per_sec = int(total_batch_size / dt)
     flops_per_sec = num_flops_per_token * total_batch_size / dt
     promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
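Note that the neighboring flops/MFU lines already used total_batch_size, so the corrected tok/sec now counts the same tokens as the MFU estimate. Continuing the sketch above with a hypothetical num_flops_per_token:

num_flops_per_token = 3.5e9     # hypothetical; depends on the model configuration
flops_per_sec = num_flops_per_token * total_batch_size / dt        # ~3.67e15 FLOPS
promised_flops_per_sec_h100 = 989e12 * ddp_world_size              # bf16 H100 SXM peak, no 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100            # ~46% in this made-up example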