diff --git a/scripts/base_train.py b/scripts/base_train.py
index bb76e90..f1b5af7 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -387,7 +387,16 @@ else:
 # Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
 tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
 world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
-assert total_batch_size % world_tokens_per_fwdbwd == 0
+# total_batch_size must be a multiple of the tokens processed in a single forward/backward pass.
+# If it isn't, gradient accumulation cannot be partitioned into equal integer steps.
+if total_batch_size % world_tokens_per_fwdbwd != 0:
+    # Calculate the nearest valid multiple (rounding down, but floored at one full pass)
+    suggested = max(world_tokens_per_fwdbwd, (total_batch_size // world_tokens_per_fwdbwd) * world_tokens_per_fwdbwd)
+    raise ValueError(
+        f"total_batch_size ({total_batch_size}) must be a multiple of "
+        f"device_batch_size * max_seq_len * world_size ({world_tokens_per_fwdbwd}). "
+        f"Try --total-batch-size={suggested}"
+    )
 grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
 print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
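
For context, here is a minimal standalone sketch of the arithmetic this check guards. The variable names mirror the script's; the concrete numbers are hypothetical, chosen only to make the divisibility relationship concrete:

```python
# Sketch of the batch-size divisibility check above (hypothetical numbers,
# not an excerpt from the script).
device_batch_size = 32       # sequences per rank per forward/backward pass
max_seq_len = 1024           # tokens per sequence
ddp_world_size = 8           # number of DDP ranks

tokens_per_fwdbwd = device_batch_size * max_seq_len           # 32,768 tokens per rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size  # 262,144 tokens per pass

total_batch_size = 2**21  # 2,097,152 tokens per optimizer step
# Divisibility holds, so accumulation splits into equal integer micro-steps:
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print(grad_accum_steps)  # -> 8 micro-steps per optimizer step
```

With, say, `total_batch_size = 2_000_000`, the modulus would be nonzero and the new `ValueError` would fire, suggesting the rounded-down multiple `1,835,008` (7 × 262,144) instead.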