From f2899a1b4a9a336a640fa3751218140607dd8812 Mon Sep 17 00:00:00 2001 From: suraj-self Date: Sun, 8 Mar 2026 16:30:53 +0530 Subject: [PATCH] Extend informative assertion message to chat_sft.py for consistency --- scripts/chat_sft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..653e7fb 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -122,7 +122,7 @@ depth = model.config.n_layer num_flops_per_token = model.estimate_flops() tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert args.total_batch_size % world_tokens_per_fwdbwd == 0 +assert args.total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({args.total_batch_size}) must be a multiple of {world_tokens_per_fwdbwd}." grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")