From 0f3b6a4654c7128567f52f7493732909b8ddcd3f Mon Sep 17 00:00:00 2001 From: suraj-self Date: Sun, 15 Feb 2026 22:43:26 +0530 Subject: [PATCH 1/4] Replace cryptic assertion with descriptive ValueError for batch size alignment --- scripts/base_train.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index bb76e90..f1b5af7 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -387,7 +387,16 @@ else: # Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert total_batch_size % world_tokens_per_fwdbwd == 0 +# total_batch_size must be a multiple of the tokens processed in a single forward/backward pass. +# If it isn't, gradient accumulation cannot be partitioned into equal integer steps. +if total_batch_size % world_tokens_per_fwdbwd != 0: + # Calculate the nearest valid multiple (rounding down, but floored at one full pass) + suggested = max(world_tokens_per_fwdbwd, (total_batch_size // world_tokens_per_fwdbwd) * world_tokens_per_fwdbwd) + raise ValueError( + f"total_batch_size ({total_batch_size}) must be a multiple of " + f"device_batch_size * max_seq_len * world_size ({world_tokens_per_fwdbwd})." + f"Try --total-batch-size={suggested}" + ) grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") From 240a60fec2828673d95a18515064cc7dbad8ffdd Mon Sep 17 00:00:00 2001 From: suraj-self Date: Mon, 16 Feb 2026 21:20:48 +0530 Subject: [PATCH 2/4] Add informative error message to batch size assertion --- scripts/base_train.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index f1b5af7..8973663 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -387,16 +387,7 @@ else: # Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -# total_batch_size must be a multiple of the tokens processed in a single forward/backward pass. -# If it isn't, gradient accumulation cannot be partitioned into equal integer steps. -if total_batch_size % world_tokens_per_fwdbwd != 0: - # Calculate the nearest valid multiple (rounding down, but floored at one full pass) - suggested = max(world_tokens_per_fwdbwd, (total_batch_size // world_tokens_per_fwdbwd) * world_tokens_per_fwdbwd) - raise ValueError( - f"total_batch_size ({total_batch_size}) must be a multiple of " - f"device_batch_size * max_seq_len * world_size ({world_tokens_per_fwdbwd})." - f"Try --total-batch-size={suggested}" - ) +assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be a multiple of {world_tokens_per_fwdbwd}. Try {max(1, total_batch_size // world_tokens_per_fwdbwd) * world_tokens_per_fwdbwd}" grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") From 998b8f846bb126e238a88b7568e64573cbb0657f Mon Sep 17 00:00:00 2001 From: suraj-self Date: Sat, 21 Feb 2026 08:43:25 +0530 Subject: [PATCH 3/4] Simplify batch size assertion message --- scripts/base_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index 9f16723..b63fe75 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -389,7 +389,7 @@ else: # Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be a multiple of {world_tokens_per_fwdbwd}. Try {max(1, total_batch_size // world_tokens_per_fwdbwd) * world_tokens_per_fwdbwd}" +assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be a multiple of {world_tokens_per_fwdbwd}." grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") From f2899a1b4a9a336a640fa3751218140607dd8812 Mon Sep 17 00:00:00 2001 From: suraj-self Date: Sun, 8 Mar 2026 16:30:53 +0530 Subject: [PATCH 4/4] Extend informative assertion message to chat_sft.py for consistency --- scripts/chat_sft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..653e7fb 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -122,7 +122,7 @@ depth = model.config.n_layer num_flops_per_token = model.estimate_flops() tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert args.total_batch_size % world_tokens_per_fwdbwd == 0 +assert args.total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({args.total_batch_size}) must be a multiple of {world_tokens_per_fwdbwd}." grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")