mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00
Merge 5172ea11bb into 4a87a0d19f
This commit is contained in:
commit
c651148338
|
|
@ -45,8 +45,8 @@ from pydantic import BaseModel
|
|||
from typing import List, Optional, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, autodetect_device_type, get_base_dir
|
||||
from nanochat.checkpoint_manager import load_model, find_largest_model, find_last_step
|
||||
from nanochat.engine import Engine
|
||||
|
||||
# Abuse prevention limits
|
||||
|
|
@ -224,6 +224,34 @@ def validate_chat_request(request: ChatRequest):
|
|||
async def lifespan(app: FastAPI):
|
||||
"""Load models on all GPUs on startup."""
|
||||
print("Loading nanochat models across GPUs...")
|
||||
|
||||
# Determine which model tag/step will be loaded (if possible) and store
|
||||
# it on the app state so HTTP endpoints can report it.
|
||||
try:
|
||||
base_dir = get_base_dir()
|
||||
model_dir = {
|
||||
"base": "base_checkpoints",
|
||||
"mid": "mid_checkpoints",
|
||||
"sft": "chatsft_checkpoints",
|
||||
"rl": "chatrl_checkpoints",
|
||||
}[args.source]
|
||||
checkpoints_dir = os.path.join(base_dir, model_dir)
|
||||
# If user supplied explicit tag/step via CLI, prefer those
|
||||
if args.model_tag is not None:
|
||||
chosen_tag = args.model_tag
|
||||
else:
|
||||
chosen_tag = find_largest_model(checkpoints_dir)
|
||||
if args.step is not None:
|
||||
chosen_step = args.step
|
||||
else:
|
||||
chosen_step = find_last_step(os.path.join(checkpoints_dir, chosen_tag))
|
||||
app.state.model_tag = chosen_tag
|
||||
app.state.model_step = chosen_step
|
||||
except Exception:
|
||||
# If anything goes wrong (no checkpoints, etc.) just leave fields unset
|
||||
app.state.model_tag = None
|
||||
app.state.model_step = None
|
||||
|
||||
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
|
||||
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
|
||||
print(f"Server ready at http://localhost:{args.port}")
|
||||
|
|
@ -400,6 +428,8 @@ async def stats():
|
|||
"total_workers": len(worker_pool.workers),
|
||||
"available_workers": worker_pool.available_workers.qsize(),
|
||||
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
|
||||
"model_tag": getattr(app.state, "model_tag", None),
|
||||
"model_step": getattr(app.state, "model_step", None),
|
||||
"workers": [
|
||||
{
|
||||
"gpu_id": w.gpu_id,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user