mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-24 13:53:27 +00:00
Replace cryptic assertion with descriptive ValueError for batch size alignment
This commit is contained in:
parent
788dadeb88
commit
0f3b6a4654
|
|
@ -387,7 +387,16 @@ else:
|
|||
# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
# total_batch_size must be a multiple of the tokens processed in a single forward/backward pass.
|
||||
# If it isn't, gradient accumulation cannot be partitioned into equal integer steps.
|
||||
# Reject a total_batch_size that does not split into a whole number of
# forward/backward passes; the integer division that computes the number of
# gradient accumulation steps would otherwise silently drop the remainder.
if total_batch_size % world_tokens_per_fwdbwd != 0:
    # Nearest valid multiple at or below the requested size, floored at one
    # full pass so the suggestion is never 0 when total_batch_size is smaller
    # than a single forward/backward pass across all ranks.
    suggested = max(world_tokens_per_fwdbwd, (total_batch_size // world_tokens_per_fwdbwd) * world_tokens_per_fwdbwd)
    # Adjacent f-string literals are concatenated with no separator, so each
    # fragment must carry its own trailing space to keep the sentences apart.
    raise ValueError(
        f"total_batch_size ({total_batch_size}) must be a multiple of "
        f"device_batch_size * max_seq_len * world_size ({world_tokens_per_fwdbwd}). "
        f"Try --total-batch-size={suggested}"
    )
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user