From 258796b04f844604518ef68bbe8369aa20a8b98c Mon Sep 17 00:00:00 2001 From: Asatov Oybek Date: Tue, 10 Mar 2026 00:22:34 +0300 Subject: [PATCH] fix: added interactive training prompt and fixed path issues --- nanochat/checkpoint_manager.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 1af830f..a47225e 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -116,25 +116,46 @@ def build_model(checkpoint_dir, step, device, phase): def find_largest_model(checkpoints_dir): - # Check if the directory exists to prevent FileNotFoundError on Windows/Linux + #check: guide the user instead of letting them quit in frustration if not os.path.exists(checkpoints_dir): - raise FileNotFoundError(f"Directory not found: {checkpoints_dir}. You may need to train a model first.") + print(f"\n[!] Directory not found: {checkpoints_dir}") + print("Note: Training is resource-intensive and might slow down your PC.") + + # Give the user a clear choice to fix the issue immediately + choice = input("Would you like to start training now to create it? (y/n): ").lower() + + if choice == 'y': + os.makedirs(checkpoints_dir, exist_ok=True) + import subprocess + import sys + try: + # Automate the training process for better Developer Experience (DX) + print("Initiating base training...") + subprocess.run([sys.executable, "-m", "scripts.base_train", "--depth=2", "--run=auto_run"], check=True) + except Exception as e: + print(f"[-] Training session interrupted or failed: {e}") + # Don't crash the whole script if the user just stopped training + sys.exit(1) + else: + # If the user declines, provide a helpful error pointing to the solution + raise FileNotFoundError(f"Directory not found: {checkpoints_dir}. Please run the training script first as described in the README.") - # attempt to guess the model tag: take the biggest model available + # --- Original logic follows --- model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))] if not model_tags: - raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}") - # 1) normally all model tags are of the form d, try that first: + raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}. Please ensure the training process finished correctly.") + candidates = [] for model_tag in model_tags: match = re.match(r"d(\d+)", model_tag) if match: model_depth = int(match.group(1)) candidates.append((model_depth, model_tag)) + if candidates: candidates.sort(key=lambda x: x[0], reverse=True) return candidates[0][1] - # 2) if that failed, take the most recently updated model: + model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True) return model_tags[0]