This commit is contained in:
Michael Williams 2025-11-17 13:34:50 -08:00 committed by GitHub
commit c651148338
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -45,8 +45,8 @@ from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass from dataclasses import dataclass
from contextlib import nullcontext from contextlib import nullcontext
from nanochat.common import compute_init, autodetect_device_type from nanochat.common import compute_init, autodetect_device_type, get_base_dir
from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import load_model, find_largest_model, find_last_step
from nanochat.engine import Engine from nanochat.engine import Engine
# Abuse prevention limits # Abuse prevention limits
@ -224,6 +224,34 @@ def validate_chat_request(request: ChatRequest):
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Load models on all GPUs on startup.""" """Load models on all GPUs on startup."""
print("Loading nanochat models across GPUs...") print("Loading nanochat models across GPUs...")
# Determine which model tag/step will be loaded (if possible) and store
# it on the app state so HTTP endpoints (e.g. /stats) can report it.
# NOTE(review): this selection logic presumably mirrors what
# WorkerPool.initialize does internally with (source, model_tag, step) —
# confirm the two stay in sync, otherwise the reported tag/step may not
# match the model actually loaded.
try:
base_dir = get_base_dir()
# Map the CLI --source value to its checkpoint directory name.
model_dir = {
"base": "base_checkpoints",
"mid": "mid_checkpoints",
"sft": "chatsft_checkpoints",
"rl": "chatrl_checkpoints",
}[args.source]
# NOTE(review): an unknown args.source raises KeyError here, which is
# silently swallowed by the broad except below — intentional best-effort.
# NOTE(review): relies on `os` being imported at module top — the diff
# only shows nanochat imports being added; confirm `import os` exists.
checkpoints_dir = os.path.join(base_dir, model_dir)
# If user supplied explicit tag/step via CLI, prefer those
if args.model_tag is not None:
chosen_tag = args.model_tag
else:
# No explicit tag: fall back to the largest model found on disk.
chosen_tag = find_largest_model(checkpoints_dir)
if args.step is not None:
chosen_step = args.step
else:
# No explicit step: use the most recent checkpoint step for the tag.
chosen_step = find_last_step(os.path.join(checkpoints_dir, chosen_tag))
app.state.model_tag = chosen_tag
app.state.model_step = chosen_step
except Exception:
# If anything goes wrong (no checkpoints, etc.) just leave fields unset
# so endpoints report null rather than failing startup.
app.state.model_tag = None
app.state.model_step = None
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus) app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step) await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
print(f"Server ready at http://localhost:{args.port}") print(f"Server ready at http://localhost:{args.port}")
@ -400,6 +428,8 @@ async def stats():
"total_workers": len(worker_pool.workers), "total_workers": len(worker_pool.workers),
"available_workers": worker_pool.available_workers.qsize(), "available_workers": worker_pool.available_workers.qsize(),
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(), "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
"model_tag": getattr(app.state, "model_tag", None),
"model_step": getattr(app.state, "model_step", None),
"workers": [ "workers": [
{ {
"gpu_id": w.gpu_id, "gpu_id": w.gpu_id,