diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 4b67b62..77bb6d2 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -45,8 +45,8 @@ from pydantic import BaseModel
 from typing import List, Optional, AsyncGenerator
 from dataclasses import dataclass
 from contextlib import nullcontext
-from nanochat.common import compute_init, autodetect_device_type
-from nanochat.checkpoint_manager import load_model
+from nanochat.common import compute_init, autodetect_device_type, get_base_dir
+from nanochat.checkpoint_manager import load_model, find_largest_model, find_last_step
 from nanochat.engine import Engine
 
 # Abuse prevention limits
@@ -224,6 +224,34 @@ def validate_chat_request(request: ChatRequest):
 async def lifespan(app: FastAPI):
     """Load models on all GPUs on startup."""
     print("Loading nanochat models across GPUs...")
+
+    # Determine which model tag/step will be loaded (if possible) and store
+    # it on the app state so HTTP endpoints can report it.
+    try:
+        base_dir = get_base_dir()
+        model_dir = {
+            "base": "base_checkpoints",
+            "mid": "mid_checkpoints",
+            "sft": "chatsft_checkpoints",
+            "rl": "chatrl_checkpoints",
+        }[args.source]
+        checkpoints_dir = os.path.join(base_dir, model_dir)
+        # If the user supplied an explicit tag/step via the CLI, prefer those
+        if args.model_tag is not None:
+            chosen_tag = args.model_tag
+        else:
+            chosen_tag = find_largest_model(checkpoints_dir)
+        if args.step is not None:
+            chosen_step = args.step
+        else:
+            chosen_step = find_last_step(os.path.join(checkpoints_dir, chosen_tag))
+        app.state.model_tag = chosen_tag
+        app.state.model_step = chosen_step
+    except Exception:
+        # If anything goes wrong (no checkpoints, etc.), just leave the fields unset
+        app.state.model_tag = None
+        app.state.model_step = None
+
     app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
     await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
     print(f"Server ready at http://localhost:{args.port}")
@@ -400,6 +428,8 @@ async def stats():
         "total_workers": len(worker_pool.workers),
         "available_workers": worker_pool.available_workers.qsize(),
         "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
+        "model_tag": getattr(app.state, "model_tag", None),
+        "model_step": getattr(app.state, "model_step", None),
         "workers": [
             {
                 "gpu_id": w.gpu_id,
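
With this change, `/stats` also reports which checkpoint the server resolved at startup. A minimal client sketch for reading the new fields, using only the standard library; the URL and port here are assumptions (the real port is whatever `--port` was set to), and both fields may be `null` when checkpoint discovery failed in `lifespan`:

```python
import json
from urllib.request import urlopen

# Hypothetical local address; substitute the actual --port value.
STATS_URL = "http://localhost:8000/stats"

with urlopen(STATS_URL) as resp:
    stats = json.load(resp)

# New fields added by this diff; both are None when no checkpoint was found.
print("model_tag:", stats.get("model_tag"))
print("model_step:", stats.get("model_step"))
```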