Mirror of https://github.com/karpathy/nanochat.git (synced 2026-05-08 16:59:59 +00:00)

Merge f409e14d3f into dc54a1a307
This commit is contained in: commit 06dc7486a8
nanochat/engine.py

@@ -17,7 +17,7 @@ import signal
 import warnings
 from contextlib import contextmanager
 from collections import deque
-from nanochat.common import compute_init, autodetect_device_type
+from nanochat.common import compute_init, autodetect_device_type, COMPUTE_DTYPE
 from nanochat.checkpoint_manager import load_model
 
 # -----------------------------------------------------------------------------
@@ -172,18 +172,16 @@ class Engine:
         self.model = model
         self.tokenizer = tokenizer # needed for tool use
 
+    def _get_kv_cache_dtype(self):
+        """Use the repo-wide compute dtype for inference cache allocation."""
+        return COMPUTE_DTYPE
+
     @torch.inference_mode()
     def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
         """Same as generate, but does single prefill and then clones the KV cache."""
         assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
         device = self.model.get_device()
-        # NOTE: setting the dtype here and in this way is an ugly hack.
-        # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
-        # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
-        # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
-        # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
-        # In particular, the KVCache should allocate its tensors lazily
-        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
+        dtype = self._get_kv_cache_dtype()
         rng = torch.Generator(device=device)
         rng.manual_seed(seed)
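A minimal sketch of the new dtype path (the DummyModel stub mirrors the one in the regression test below; the printed value is whatever nanochat.common.COMPUTE_DTYPE resolves to on the host):

import torch
from nanochat.engine import Engine

class DummyModel:
    # stand-in for a real model; only get_device() is needed here
    def get_device(self):
        return torch.device("cpu")

engine = Engine(DummyModel(), tokenizer=None)
# the KV cache dtype now comes from one repo-wide constant instead of being
# re-derived from the device type inside generate()
print(engine._get_kv_cache_dtype())  # == nanochat.common.COMPUTE_DTYPE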
scripts/chat_web.py

@@ -33,6 +33,7 @@ Abuse Prevention:
 import argparse
 import json
 import os
+import sys
 import torch
 import asyncio
 import logging
@@ -59,29 +60,54 @@ MAX_TOP_K = 200
 MIN_MAX_TOKENS = 1
 MAX_MAX_TOKENS = 4096
 
-parser = argparse.ArgumentParser(description='NanoChat Web Server')
-parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
-parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
-parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
-parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
-parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
-parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
-parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
-parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
-parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
-parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
-args = parser.parse_args()
-
-# Configure logging for conversation traffic
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S'
-)
-logger = logging.getLogger(__name__)
-
-device_type = autodetect_device_type() if args.device_type == "" else args.device_type
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+def build_parser():
+    parser = argparse.ArgumentParser(description='NanoChat Web Server')
+    parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
+    parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
+    parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
+    parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
+    parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
+    parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
+    parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
+    parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
+    parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
+    parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
+    return parser
+
+
+def parse_args(argv=None):
+    parser = build_parser()
+    if argv is None:
+        argv = sys.argv[1:]
+    return parser.parse_args(argv)
+
+
+args = parse_args([])
+device_type = None
+ddp = ddp_rank = ddp_local_rank = ddp_world_size = None
+device = None
+_runtime_initialized = False
+
+
+def configure_runtime(parsed_args=None):
+    global args, device_type, ddp, ddp_rank, ddp_local_rank, ddp_world_size, device, _runtime_initialized
+
+    if parsed_args is not None:
+        args = parsed_args
+    if _runtime_initialized:
+        return
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+
+    device_type = autodetect_device_type() if args.device_type == "" else args.device_type
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+    _runtime_initialized = True
 
 @dataclass
 class Worker:
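A minimal sketch of the intended import-time behavior after this refactor (flag values are illustrative): importing scripts.chat_web now only parses an empty argv for defaults, and logging/device setup is deferred to configure_runtime().

from scripts.chat_web import parse_args, configure_runtime

# explicit argv keeps tests and embedding code independent of sys.argv
args = parse_args(["--device-type", "cpu", "--port", "9000"])
configure_runtime(args)  # sets up logging and calls compute_init() once
configure_runtime(args)  # later calls return early via _runtime_initialized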
@@ -94,9 +120,10 @@ class Worker:
 class WorkerPool:
     """Pool of workers, each with a model replica on a different GPU."""
 
-    def __init__(self, num_gpus: Optional[int] = None):
+    def __init__(self, runtime_device_type: str, num_gpus: Optional[int] = None):
+        self.device_type = runtime_device_type
         if num_gpus is None:
-            if device_type == "cuda":
+            if self.device_type == "cuda":
                 num_gpus = torch.cuda.device_count()
             else:
                 num_gpus = 1 # e.g. cpu|mps
@@ -108,16 +135,16 @@ class WorkerPool:
         """Load model on each GPU."""
         print(f"Initializing worker pool with {self.num_gpus} GPUs...")
         if self.num_gpus > 1:
-            assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
+            assert self.device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
 
         for gpu_id in range(self.num_gpus):
 
-            if device_type == "cuda":
+            if self.device_type == "cuda":
                 device = torch.device(f"cuda:{gpu_id}")
                 print(f"Loading model on GPU {gpu_id}...")
             else:
-                device = torch.device(device_type) # e.g. cpu|mps
-                print(f"Loading model on {device_type}...")
+                device = torch.device(self.device_type) # e.g. cpu|mps
+                print(f"Loading model on {self.device_type}...")
 
             model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
             engine = Engine(model, tokenizer)
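With the constructor change, a pool is built against an explicit device type rather than the module-level global; a rough sketch (arguments illustrative):

from scripts.chat_web import WorkerPool

pool = WorkerPool("cpu", num_gpus=1)   # cpu or mps: a single worker
# pool = WorkerPool("cuda")            # num_gpus defaults to torch.cuda.device_count()
# models are still loaded later, inside an event loop:
# await pool.initialize("sft", model_tag=None, step=None)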
@@ -186,7 +213,7 @@ def validate_chat_request(request: ChatRequest):
         if message.role not in ["user", "assistant"]:
             raise HTTPException(
                 status_code=400,
-                detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
+                detail=f"Message {i} has invalid role. Must be 'user' or 'assistant'"
             )
 
     # Validate temperature
@@ -216,8 +243,9 @@ def validate_chat_request(request: ChatRequest):
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Load models on all GPUs on startup."""
+    configure_runtime()
     print("Loading nanochat models across GPUs...")
-    app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
+    app.state.worker_pool = WorkerPool(device_type, num_gpus=args.num_gpus)
     await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
     print(f"Server ready at http://localhost:{args.port}")
     yield
@@ -402,6 +430,7 @@ async def stats():
 
 if __name__ == "__main__":
     import uvicorn
+    configure_runtime(parse_args())
     print(f"Starting NanoChat Web Server")
     print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
     uvicorn.run(app, host=args.host, port=args.port)
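Because lifespan() now calls configure_runtime() itself, the app can presumably be served without going through the __main__ block; a hedged sketch (host/port illustrative, uvicorn assumed available):

import uvicorn
from scripts.chat_web import app, parse_args, configure_runtime

# optional: pick non-default flags up front; otherwise lifespan() initializes
# the runtime with the module defaults from parse_args([])
configure_runtime(parse_args(["--device-type", "cpu"]))
uvicorn.run(app, host="127.0.0.1", port=8000)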
tests/test_regressions.py (new file, 38 lines)

@@ -0,0 +1,38 @@
+import importlib
+import sys
+
+import pytest
+import torch
+from fastapi import HTTPException
+
+import nanochat.engine as engine_module
+from nanochat.engine import Engine
+
+
+class DummyModel:
+    def get_device(self):
+        return torch.device("cpu")
+
+
+def test_engine_kv_cache_uses_compute_dtype(monkeypatch):
+    monkeypatch.setattr(engine_module, "COMPUTE_DTYPE", torch.float16)
+    engine = Engine(DummyModel(), tokenizer=None)
+    assert engine._get_kv_cache_dtype() == torch.float16
+
+
+def test_chat_web_rejects_system_role_with_consistent_error(monkeypatch):
+    monkeypatch.setattr(sys, "argv", ["chat_web_test", "--definitely-not-a-valid-flag"])
+    sys.modules.pop("scripts.chat_web", None)
+    chat_web = importlib.import_module("scripts.chat_web")
+
+    request = chat_web.ChatRequest(
+        messages=[chat_web.ChatMessage(role="system", content="You are helpful.")]
+    )
+
+    assert chat_web._runtime_initialized is False
+
+    with pytest.raises(HTTPException) as exc_info:
+        chat_web.validate_chat_request(request)
+
+    assert exc_info.value.status_code == 400
+    assert exc_info.value.detail == "Message 0 has invalid role. Must be 'user' or 'assistant'"
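The new checks should run with the repo's usual pytest invocation, e.g. pytest tests/test_regressions.py -q. The second test deliberately puts a bogus flag on sys.argv before importing scripts.chat_web, confirming that importing the module no longer parses the command line or initializes a device.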