Replace cryptic assertion with descriptive ValueError for batch size alignment

suraj-self 2026-02-15 22:43:26 +05:30
parent 788dadeb88
commit 0f3b6a4654


@@ -387,7 +387,16 @@ else:
 # Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
 tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
 world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
-assert total_batch_size % world_tokens_per_fwdbwd == 0
+# total_batch_size must be a multiple of the tokens processed in a single forward/backward pass.
+# If it isn't, gradient accumulation cannot be partitioned into equal integer steps.
+if total_batch_size % world_tokens_per_fwdbwd != 0:
+    # Calculate the nearest valid multiple (rounding down, but floored at one full pass)
+    suggested = max(world_tokens_per_fwdbwd, (total_batch_size // world_tokens_per_fwdbwd) * world_tokens_per_fwdbwd)
+    raise ValueError(
+        f"total_batch_size ({total_batch_size}) must be a multiple of "
+        f"device_batch_size * max_seq_len * world_size ({world_tokens_per_fwdbwd}). "
+        f"Try --total-batch-size={suggested}"
+    )
 grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
 print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")