fix(miniseries): extract tokens_trained from log instead of hardcoding batch size

Same bug as scaling_laws.sh: TOKENS_TRAINED was computed as NUM_ITERS * 524288,
hardcoding the default total batch size. When base_train auto-computes a different
batch size, the value is wrong. Fix by reading "Total number of training tokens:"
directly from the training log.
This commit is contained in:
geopti 2026-02-28 20:43:34 +00:00
parent fb2be07e17
commit 16755495bc

View File

@@ -85,7 +85,7 @@ for d in "${DEPTHS[@]}"; do
# Scrape per-run metrics out of the training log at $LOG_FILE. Every lookup
# uses `tail -1`, so when a label appears more than once the last occurrence wins.
# NOTE(review): this excerpt appears to be a diff hunk whose +/- markers were
# lost; the two consecutive TOKENS_TRAINED assignments below look like the
# removed and the added version of the same line. Confirm against the real file.

# Parameter count: first run of digits/commas on the "Number of parameters:"
# line, with the thousands separators removed.
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
# Scaling-parameter count: the number that follows "scaling: " on that same line.
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')
# Iteration count: whatever follows the last ": " on the line, commas removed.
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
# Token count derived as iterations * 524288, i.e. a fixed tokens-per-iteration
# constant (presumably the pre-fix line; the assignment below overrides it).
TOKENS_TRAINED=$((NUM_ITERS * 524288))
# Token count read from the "Total number of training tokens:" log line.
# NOTE(review): unlike NUM_PARAMS there is no `head -1` here, so every
# digit/comma run on the matched line is emitted, one per output line. This
# assumes the line carries exactly one number; verify against the log format.
TOKENS_TRAINED=$(grep "Total number of training tokens:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
# Tokens per scaling parameter, formatted to two decimals by Python.
# NOTE(review): both values are interpolated into the Python source text, so an
# empty or multi-line value makes the Python snippet itself invalid.
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
# Model width computed as 64 * d (d is the loop variable; presumably the depth).
MODEL_DIM=$((d * 64))
# Validation bits-per-byte: the run of digits/dots at the very end of the
# last "Validation bpb:" line.
VAL_BPB=$(grep "Validation bpb:" "$LOG_FILE" | tail -1 | grep -oP '[\d.]+$')