fix: added interactive training prompt and fixed path issues

This commit is contained in:
Asatov Oybek 2026-03-10 00:22:34 +03:00
parent d1fae8c1d5
commit 258796b04f

View File

@ -116,25 +116,46 @@ def build_model(checkpoint_dir, step, device, phase):
def find_largest_model(checkpoints_dir):
    """Guess which model tag to load from ``checkpoints_dir``.

    Every sub-directory of ``checkpoints_dir`` is treated as a model tag.
    Tags that start with ``d<number>`` are preferred and the one with the
    largest number wins; if no tag has that form, the most recently modified
    sub-directory is returned instead.

    If ``checkpoints_dir`` does not exist and stdin is an interactive
    terminal, the user is offered the chance to launch base training first
    (see ``_offer_to_train``).

    Args:
        checkpoints_dir: Path of the directory holding one sub-directory per
            model tag.

    Returns:
        The name (not the full path) of the selected sub-directory.

    Raises:
        FileNotFoundError: If the directory is missing and training was not
            started, or if it contains no sub-directories.
        SystemExit: If the user started training and it failed or was
            interrupted.
    """
    # Guide the user instead of failing on a bare os.listdir() error.
    if not os.path.exists(checkpoints_dir):
        _offer_to_train(checkpoints_dir)

    # Attempt to guess the model tag: take the biggest model available.
    model_tags = [
        f for f in os.listdir(checkpoints_dir)
        if os.path.isdir(os.path.join(checkpoints_dir, f))
    ]
    if not model_tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}. Please ensure the training process finished correctly.")

    # 1) Normally all model tags are of the form d<number>, try that first.
    candidates = []
    for model_tag in model_tags:
        match = re.match(r"d(\d+)", model_tag)
        if match:
            candidates.append((int(match.group(1)), model_tag))
    if candidates:
        # max() returns the first maximal entry, which matches the result of
        # a stable descending sort followed by taking element 0.
        return max(candidates, key=lambda c: c[0])[1]

    # 2) If that failed, take the most recently updated model.
    return max(model_tags, key=lambda tag: os.path.getmtime(os.path.join(checkpoints_dir, tag)))


def _offer_to_train(checkpoints_dir):
    """Ask the user whether to run base training to create ``checkpoints_dir``.

    Returns normally only when the training subprocess exited successfully.
    Raises FileNotFoundError when the user declines or when stdin is not an
    interactive terminal; exits the process with status 1 when training fails
    or is interrupted.
    """
    # Local imports: only needed on this rarely taken path.
    import subprocess
    import sys

    declined_message = f"Directory not found: {checkpoints_dir}. Please run the training script first as described in the README."

    # Never prompt when nobody can answer (CI, pipes, background launches):
    # input() would raise or block there instead of reporting the real problem.
    if sys.stdin is None or not sys.stdin.isatty():
        raise FileNotFoundError(declined_message)

    print(f"\n[!] Directory not found: {checkpoints_dir}")
    print("Note: Training is resource-intensive and might slow down your PC.")
    try:
        # Give the user a clear choice to fix the issue immediately.
        choice = input("Would you like to start training now to create it? (y/n): ").strip().lower()
    except EOFError:
        # Treat a closed stdin (e.g. Ctrl+D) as "no".
        choice = ""
    if choice not in ("y", "yes"):
        # If the user declines, point them at the solution.
        raise FileNotFoundError(declined_message)

    os.makedirs(checkpoints_dir, exist_ok=True)
    try:
        print("Initiating base training...")
        subprocess.run([sys.executable, "-m", "scripts.base_train", "--depth=2", "--run=auto_run"], check=True)
    except (subprocess.CalledProcessError, OSError, KeyboardInterrupt) as e:
        print(f"[-] Training session interrupted or failed: {e}")
        # Remove the directory created above if it is still empty, so the next
        # call offers training again instead of reporting "No checkpoints found".
        try:
            os.rmdir(checkpoints_dir)
        except OSError:
            # Best effort: the directory is non-empty or already gone.
            pass
        # Exit with a short message rather than a traceback.
        sys.exit(1)