This commit is contained in:
Michael Williams 2025-11-17 13:34:50 -08:00 committed by GitHub
commit c651148338
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -45,8 +45,8 @@ from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from contextlib import nullcontext
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, autodetect_device_type, get_base_dir
from nanochat.checkpoint_manager import load_model, find_largest_model, find_last_step
from nanochat.engine import Engine
# Abuse prevention limits
@@ -224,6 +224,34 @@ def validate_chat_request(request: ChatRequest):
async def lifespan(app: FastAPI):
"""Load models on all GPUs on startup."""
print("Loading nanochat models across GPUs...")
# Determine which model tag/step will be loaded (if possible) and store
# it on the app state so HTTP endpoints can report it.
try:
base_dir = get_base_dir()
model_dir = {
"base": "base_checkpoints",
"mid": "mid_checkpoints",
"sft": "chatsft_checkpoints",
"rl": "chatrl_checkpoints",
}[args.source]
checkpoints_dir = os.path.join(base_dir, model_dir)
# If user supplied explicit tag/step via CLI, prefer those
if args.model_tag is not None:
chosen_tag = args.model_tag
else:
chosen_tag = find_largest_model(checkpoints_dir)
if args.step is not None:
chosen_step = args.step
else:
chosen_step = find_last_step(os.path.join(checkpoints_dir, chosen_tag))
app.state.model_tag = chosen_tag
app.state.model_step = chosen_step
except Exception:
# If anything goes wrong (no checkpoints, etc.) just leave fields unset
app.state.model_tag = None
app.state.model_step = None
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
print(f"Server ready at http://localhost:{args.port}")
@@ -400,6 +428,8 @@ async def stats():
"total_workers": len(worker_pool.workers),
"available_workers": worker_pool.available_workers.qsize(),
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
"model_tag": getattr(app.state, "model_tag", None),
"model_step": getattr(app.state, "model_step", None),
"workers": [
{
"gpu_id": w.gpu_id,