diff --git a/scripts/base_train.py b/scripts/base_train.py index 4bf7959..7d3fecf 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -80,6 +80,7 @@ parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints # Output parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") args = parser.parse_args() +assert args.target_param_data_ratio > 0 or args.target_param_data_ratio == -1, "target-param-data-ratio must be positive (or -1 to disable)" user_config = vars(args).copy() # for logging # ----------------------------------------------------------------------------- # Compute init and wandb logging