From 63395bbadebe2509d9f94b61ce801d059632853d Mon Sep 17 00:00:00 2001 From: Sang Hun Kim Date: Sat, 11 Apr 2026 15:09:18 +0900 Subject: [PATCH 1/3] Fix chat validation message and engine KV cache dtype --- nanochat/engine.py | 14 ++++++-------- scripts/chat_web.py | 38 +++++++++++++++++++++++++------------- tests/test_regressions.py | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 21 deletions(-) create mode 100644 tests/test_regressions.py diff --git a/nanochat/engine.py b/nanochat/engine.py index aa2e6a98..f133d891 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -17,7 +17,7 @@ import signal import warnings from contextlib import contextmanager from collections import deque -from nanochat.common import compute_init, autodetect_device_type +from nanochat.common import compute_init, autodetect_device_type, COMPUTE_DTYPE from nanochat.checkpoint_manager import load_model # ----------------------------------------------------------------------------- @@ -172,18 +172,16 @@ class Engine: self.model = model self.tokenizer = tokenizer # needed for tool use + def _get_kv_cache_dtype(self): + """Use the repo-wide compute dtype for inference cache allocation.""" + return COMPUTE_DTYPE + @torch.inference_mode() def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42): """Same as generate, but does single prefill and then clones the KV cache.""" assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints" device = self.model.get_device() - # NOTE: setting the dtype here and in this way is an ugly hack. - # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32. - # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors. - # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption. - # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase. 
-        # In particular, the KVCache should allocate its tensors lazily
-        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
+        dtype = self._get_kv_cache_dtype()
         rng = torch.Generator(device=device)
         rng.manual_seed(seed)
 
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index ffaf7dab..9cce083e 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -33,6 +33,7 @@ Abuse Prevention:
 import argparse
 import json
 import os
+import sys
 import torch
 import asyncio
 import logging
@@ -59,18 +60,29 @@ MAX_TOP_K = 200
 MIN_MAX_TOKENS = 1
 MAX_MAX_TOKENS = 4096
 
-parser = argparse.ArgumentParser(description='NanoChat Web Server')
-parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
-parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
-parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
-parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
-parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
-parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
-parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
-parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
-parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
-parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
-args = parser.parse_args()
+def build_parser():
+    parser = argparse.ArgumentParser(description='NanoChat Web Server')
+    parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
+    parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
+    parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
+    parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
+    parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
+    parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
+    parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
+    parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
+    parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
+    parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
+    return parser
+
+
+def parse_args(argv=None):
+    parser = build_parser()
+    if argv is None:
+        argv = sys.argv[1:]
+    return parser.parse_args(argv)
+
+
+args = parse_args()
 
 # Configure logging for conversation traffic
 logging.basicConfig(
@@ -186,7 +198,7 @@ def validate_chat_request(request: ChatRequest):
         if message.role not in ["user", "assistant"]:
             raise HTTPException(
                 status_code=400,
-                detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
+                detail=f"Message {i} has invalid role. Must be 'user' or 'assistant'"
             )
 
     # Validate temperature
diff --git a/tests/test_regressions.py b/tests/test_regressions.py
new file mode 100644
index 00000000..ed332d1f
--- /dev/null
+++ b/tests/test_regressions.py
@@ -0,0 +1,36 @@
+import importlib
+import sys
+
+import pytest
+import torch
+from fastapi import HTTPException
+
+import nanochat.engine as engine_module
+from nanochat.engine import Engine
+
+
+class DummyModel:
+    def get_device(self):
+        return torch.device("cpu")
+
+
+def test_engine_kv_cache_uses_compute_dtype(monkeypatch):
+    monkeypatch.setattr(engine_module, "COMPUTE_DTYPE", torch.float16)
+    engine = Engine(DummyModel(), tokenizer=None)
+    assert engine._get_kv_cache_dtype() == torch.float16
+
+
+def test_chat_web_rejects_system_role_with_consistent_error(monkeypatch):
+    monkeypatch.setattr(sys, "argv", ["chat_web_test"])
+    sys.modules.pop("scripts.chat_web", None)
+    chat_web = importlib.import_module("scripts.chat_web")
+
+    request = chat_web.ChatRequest(
+        messages=[chat_web.ChatMessage(role="system", content="You are helpful.")]
+    )
+
+    with pytest.raises(HTTPException) as exc_info:
+        chat_web.validate_chat_request(request)
+
+    assert exc_info.value.status_code == 400
+    assert exc_info.value.detail == "Message 0 has invalid role. Must be 'user' or 'assistant'"

From 7b4549495c105f4c39ec12269e5ee0a45a3a80d2 Mon Sep 17 00:00:00 2001
From: Sang Hun Kim
Date: Sat, 11 Apr 2026 15:24:21 +0900
Subject: [PATCH 2/3] Refactor chat_web startup for safe imports

---
 scripts/chat_web.py       | 51 ++++++++++++++++++++++++++-------------
 tests/test_regressions.py |  4 ++-
 2 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 9cce083e..4c8b5834 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -60,6 +60,8 @@ MAX_TOP_K = 200
 MIN_MAX_TOKENS = 1
 MAX_MAX_TOKENS = 4096
 
+logger = logging.getLogger(__name__)
+
 def build_parser():
     parser = argparse.ArgumentParser(description='NanoChat Web Server')
     parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
@@ -82,18 +84,30 @@ def parse_args(argv=None):
     return parser.parse_args(argv)
 
 
-args = parse_args()
+args = parse_args([])
+device_type = None
+ddp = ddp_rank = ddp_local_rank = ddp_world_size = None
+device = None
+_runtime_initialized = False
 
-# Configure logging for conversation traffic
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S'
-)
-logger = logging.getLogger(__name__)
-device_type = autodetect_device_type() if args.device_type == "" else args.device_type
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-
+
+def configure_runtime(parsed_args=None):
+    global args, device_type, ddp, ddp_rank, ddp_local_rank, ddp_world_size, device, _runtime_initialized
+
+    if parsed_args is not None:
+        args = parsed_args
+    if _runtime_initialized:
+        return
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+
+    device_type = autodetect_device_type() if args.device_type == "" else args.device_type
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+    _runtime_initialized = True
 
 @dataclass
 class Worker:
@@ -106,7 +120,8 @@ class WorkerPool:
     """Pool of workers, each with a model replica on a different GPU."""
-    def __init__(self, num_gpus: Optional[int] = None):
+    def __init__(self, runtime_device_type: str, num_gpus: Optional[int] = None):
+        self.device_type = runtime_device_type
         if num_gpus is None:
-            if device_type == "cuda":
+            if self.device_type == "cuda":
                 num_gpus = torch.cuda.device_count()
             else:
                 num_gpus = 1 # e.g. cpu|mps
@@ -120,15 +135,15 @@ class WorkerPool:
         """Load model on each GPU."""
         print(f"Initializing worker pool with {self.num_gpus} GPUs...")
         if self.num_gpus > 1:
-            assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
+            assert self.device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
 
         for gpu_id in range(self.num_gpus):
-            if device_type == "cuda":
+            if self.device_type == "cuda":
                 device = torch.device(f"cuda:{gpu_id}")
                 print(f"Loading model on GPU {gpu_id}...")
             else:
-                device = torch.device(device_type) # e.g. cpu|mps
-                print(f"Loading model on {device_type}...")
+                device = torch.device(self.device_type) # e.g. cpu|mps
+                print(f"Loading model on {self.device_type}...")
 
             model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
             engine = Engine(model, tokenizer)
@@ -228,8 +243,9 @@ def validate_chat_request(request: ChatRequest):
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Load models on all GPUs on startup."""
+    configure_runtime()
     print("Loading nanochat models across GPUs...")
-    app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
+    app.state.worker_pool = WorkerPool(device_type, num_gpus=args.num_gpus)
     await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
     print(f"Server ready at http://localhost:{args.port}")
     yield
@@ -414,6 +430,7 @@ async def stats():
 
 if __name__ == "__main__":
     import uvicorn
+    configure_runtime(parse_args())
     print(f"Starting NanoChat Web Server")
     print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
     uvicorn.run(app, host=args.host, port=args.port)
diff --git a/tests/test_regressions.py b/tests/test_regressions.py
index ed332d1f..a8dbea5d 100644
--- a/tests/test_regressions.py
+++ b/tests/test_regressions.py
@@ -21,7 +21,7 @@ def test_engine_kv_cache_uses_compute_dtype(monkeypatch):
 
 
 def test_chat_web_rejects_system_role_with_consistent_error(monkeypatch):
-    monkeypatch.setattr(sys, "argv", ["chat_web_test"])
+    monkeypatch.setattr(sys, "argv", ["chat_web_test", "--definitely-not-a-valid-flag"])
     sys.modules.pop("scripts.chat_web", None)
     chat_web = importlib.import_module("scripts.chat_web")
 
@@ -29,6 +29,8 @@ def test_chat_web_rejects_system_role_with_consistent_error(monkeypatch):
         messages=[chat_web.ChatMessage(role="system", content="You are helpful.")]
     )
 
+    assert chat_web._runtime_initialized is False
+
     with pytest.raises(HTTPException) as exc_info:
         chat_web.validate_chat_request(request)
 

From f409e14d3f4371bd4fd8c2e77aa9402caa1be1df Mon Sep 17 00:00:00 2001
From: Sang Hun Kim
Date: Fri, 17 Apr 2026 18:36:37 +0900
Subject: [PATCH 3/3] Verify tests pass with uv environment
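
Note for reviewers, appended after the series rather than inside any patch:
patch 1 makes nanochat/engine.py import COMPUTE_DTYPE from nanochat.common,
but no hunk in the series adds that constant, so it is assumed to exist there
already. In case it does not, a minimal sketch of the assumed definition
follows. The name COMPUTE_DTYPE and its home in nanochat/common.py are taken
from the import in patch 1; the selection logic below is only an assumption
that mirrors the "cuda -> bfloat16, everything else -> float32" convention
described in the comment that patch 1 removes.

    # nanochat/common.py -- hypothetical sketch, not part of this series.
    import torch

    # Repo-wide compute dtype used by the engine for KV-cache allocation:
    # bfloat16 when CUDA is available, float32 everywhere else (cpu|mps).
    COMPUTE_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

Patch 3 says the tests were verified in a uv-managed environment; the exact
command is not recorded in the series, but one plausible invocation is
`uv run pytest tests/test_regressions.py`.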