mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00

Merge 5172ea11bb into 4a87a0d19f
This commit is contained in: c651148338
@@ -45,8 +45,8 @@ from pydantic import BaseModel
 from typing import List, Optional, AsyncGenerator
 from dataclasses import dataclass
 from contextlib import nullcontext
-from nanochat.common import compute_init, autodetect_device_type
-from nanochat.checkpoint_manager import load_model
+from nanochat.common import compute_init, autodetect_device_type, get_base_dir
+from nanochat.checkpoint_manager import load_model, find_largest_model, find_last_step
 from nanochat.engine import Engine

 # Abuse prevention limits
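The three new imports are used in the lifespan hook below: get_base_dir locates the nanochat artifacts directory, while find_largest_model and find_last_step auto-discover a checkpoint when none is pinned on the CLI. For orientation only, here is a hypothetical sketch of what such helpers could look like, assuming model tags are subdirectories with a numeric size suffix (e.g. "d26") and checkpoints are files named model_<step>.pt; neither layout is confirmed by this diff, and the real implementations live in nanochat.common and nanochat.checkpoint_manager:

    # Hypothetical sketch, not nanochat's actual implementation.
    import os
    import re

    def get_base_dir():
        # Assumption: artifacts default to ~/.cache/nanochat.
        return os.environ.get("NANOCHAT_BASE_DIR", os.path.expanduser("~/.cache/nanochat"))

    def find_largest_model(checkpoints_dir):
        # Assumption: tags are subdirectories like "d20", "d26";
        # pick the numerically largest one.
        tags = [d for d in os.listdir(checkpoints_dir)
                if os.path.isdir(os.path.join(checkpoints_dir, d))]
        return max(tags, key=lambda t: int(re.sub(r"\D", "", t) or "0"))

    def find_last_step(tag_dir):
        # Assumption: checkpoints are saved as model_<step>.pt; return the highest step.
        steps = [int(m.group(1)) for f in os.listdir(tag_dir)
                 if (m := re.match(r"model_(\d+)\.pt$", f))]
        return max(steps)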
@@ -224,6 +224,34 @@ def validate_chat_request(request: ChatRequest):
 async def lifespan(app: FastAPI):
     """Load models on all GPUs on startup."""
     print("Loading nanochat models across GPUs...")
+
+    # Determine which model tag/step will be loaded (if possible) and store
+    # it on the app state so HTTP endpoints can report it.
+    try:
+        base_dir = get_base_dir()
+        model_dir = {
+            "base": "base_checkpoints",
+            "mid": "mid_checkpoints",
+            "sft": "chatsft_checkpoints",
+            "rl": "chatrl_checkpoints",
+        }[args.source]
+        checkpoints_dir = os.path.join(base_dir, model_dir)
+        # If user supplied explicit tag/step via CLI, prefer those
+        if args.model_tag is not None:
+            chosen_tag = args.model_tag
+        else:
+            chosen_tag = find_largest_model(checkpoints_dir)
+        if args.step is not None:
+            chosen_step = args.step
+        else:
+            chosen_step = find_last_step(os.path.join(checkpoints_dir, chosen_tag))
+        app.state.model_tag = chosen_tag
+        app.state.model_step = chosen_step
+    except Exception:
+        # If anything goes wrong (no checkpoints, etc.) just leave fields unset
+        app.state.model_tag = None
+        app.state.model_step = None
+
     app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
     await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
     print(f"Server ready at http://localhost:{args.port}")
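Two things are worth noting in this hunk. First, resolution is CLI-first: an explicit --model-tag or --step wins, and directory auto-discovery is only the fallback. Second, worker_pool.initialize() still receives the raw args.model_tag/args.step (possibly None), so the tag/step stored on app.state is a best-effort report rather than a value the workers are guaranteed to load. A minimal restatement of the precedence, with hypothetical values:

    def resolve(cli_value, discover):
        """Explicit CLI value wins; otherwise fall back to auto-discovery."""
        return cli_value if cli_value is not None else discover()

    # Mirrors the hunk: tag first, then step within the chosen tag's directory.
    chosen_tag = resolve(None, lambda: "d26")   # -> "d26" (auto-discovered)
    chosen_step = resolve(650, lambda: 800)     # -> 650 (CLI value wins)
    print(chosen_tag, chosen_step)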
@@ -400,6 +428,8 @@ async def stats():
         "total_workers": len(worker_pool.workers),
         "available_workers": worker_pool.available_workers.qsize(),
         "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
+        "model_tag": getattr(app.state, "model_tag", None),
+        "model_step": getattr(app.state, "model_step", None),
         "workers": [
             {
                 "gpu_id": w.gpu_id,
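With the two new fields, a client can ask a running server which checkpoint it is serving. A small client sketch, assuming the handler above is mounted at /stats and the server listens on localhost port 8000; both the route path and the port come from the surrounding script and CLI flags, not from this hunk:

    import json
    import urllib.request

    # Assumption: /stats route on port 8000 (set via --port elsewhere).
    with urllib.request.urlopen("http://localhost:8000/stats") as resp:
        stats = json.load(resp)

    print(stats.get("model_tag"), stats.get("model_step"))  # e.g. d26 650
    print(f'{stats["busy_workers"]}/{stats["total_workers"]} workers busy')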