From 998b8f846bb126e238a88b7568e64573cbb0657f Mon Sep 17 00:00:00 2001 From: suraj-self Date: Sat, 21 Feb 2026 08:43:25 +0530 Subject: [PATCH] Simplify batch size assertion message --- scripts/base_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index 9f16723..b63fe75 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -389,7 +389,7 @@ else: # Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be a multiple of {world_tokens_per_fwdbwd}. Try {max(1, total_batch_size // world_tokens_per_fwdbwd) * world_tokens_per_fwdbwd}" +assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be a multiple of {world_tokens_per_fwdbwd}." grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")