mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 21:55:14 +00:00
fix: added interactive training prompt and fixed path issues
This commit is contained in:
parent
d1fae8c1d5
commit
258796b04f
|
|
@ -116,25 +116,46 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
|
||||
|
||||
def find_largest_model(checkpoints_dir):
|
||||
# Check if the directory exists to prevent FileNotFoundError on Windows/Linux
|
||||
#check: guide the user instead of letting them quit in frustration
|
||||
if not os.path.exists(checkpoints_dir):
|
||||
raise FileNotFoundError(f"Directory not found: {checkpoints_dir}. You may need to train a model first.")
|
||||
print(f"\n[!] Directory not found: {checkpoints_dir}")
|
||||
print("Note: Training is resource-intensive and might slow down your PC.")
|
||||
|
||||
# Give the user a clear choice to fix the issue immediately
|
||||
choice = input("Would you like to start training now to create it? (y/n): ").lower()
|
||||
|
||||
if choice == 'y':
|
||||
os.makedirs(checkpoints_dir, exist_ok=True)
|
||||
import subprocess
|
||||
import sys
|
||||
try:
|
||||
# Automate the training process for better Developer Experience (DX)
|
||||
print("Initiating base training...")
|
||||
subprocess.run([sys.executable, "-m", "scripts.base_train", "--depth=2", "--run=auto_run"], check=True)
|
||||
except Exception as e:
|
||||
print(f"[-] Training session interrupted or failed: {e}")
|
||||
# Don't crash the whole script if the user just stopped training
|
||||
sys.exit(1)
|
||||
else:
|
||||
# If the user declines, provide a helpful error pointing to the solution
|
||||
raise FileNotFoundError(f"Directory not found: {checkpoints_dir}. Please run the training script first as described in the README.")
|
||||
|
||||
# attempt to guess the model tag: take the biggest model available
|
||||
# --- Original logic follows ---
|
||||
model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
|
||||
if not model_tags:
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
|
||||
# 1) normally all model tags are of the form d<number>, try that first:
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}. Please ensure the training process finished correctly.")
|
||||
|
||||
candidates = []
|
||||
for model_tag in model_tags:
|
||||
match = re.match(r"d(\d+)", model_tag)
|
||||
if match:
|
||||
model_depth = int(match.group(1))
|
||||
candidates.append((model_depth, model_tag))
|
||||
|
||||
if candidates:
|
||||
candidates.sort(key=lambda x: x[0], reverse=True)
|
||||
return candidates[0][1]
|
||||
# 2) if that failed, take the most recently updated model:
|
||||
|
||||
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
|
||||
return model_tags[0]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user