mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-24 05:43:31 +00:00
Simplify batch size assertion message
This commit is contained in:
parent
d489a1fa22
commit
998b8f846b
|
|
@ -389,7 +389,7 @@ else:
|
|||
# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be a multiple of {world_tokens_per_fwdbwd}. Try {max(1, total_batch_size // world_tokens_per_fwdbwd) * world_tokens_per_fwdbwd}"
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be a multiple of {world_tokens_per_fwdbwd}."
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user