Fix chat validation message and engine KV cache dtype

This commit is contained in:
Sang Hun Kim 2026-04-11 15:09:18 +09:00
parent a445144d39
commit 63395bbade
3 changed files with 67 additions and 21 deletions

View File

@ -17,7 +17,7 @@ import signal
import warnings
from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type
from nanochat.common import compute_init, autodetect_device_type, COMPUTE_DTYPE
from nanochat.checkpoint_manager import load_model
# -----------------------------------------------------------------------------
@ -172,18 +172,16 @@ class Engine:
self.model = model
self.tokenizer = tokenizer # needed for tool use
def _get_kv_cache_dtype(self):
    """Return the dtype used to allocate the inference KV cache.

    The value is the repo-wide ``COMPUTE_DTYPE``.
    """
    # Looked up at call time, so a patched module-level COMPUTE_DTYPE is honoured.
    kv_dtype = COMPUTE_DTYPE
    return kv_dtype
@torch.inference_mode()
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
"""Same as generate, but does single prefill and then clones the KV cache."""
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
device = self.model.get_device()
# NOTE: setting the dtype here and in this way is an ugly hack.
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
# We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
# As a quick hack, we're making the generate() function inherit and know about this repo-wide assumption.
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
# In particular, the KVCache should allocate its tensors lazily
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
dtype = self._get_kv_cache_dtype()
rng = torch.Generator(device=device)
rng.manual_seed(seed)

View File

@ -33,6 +33,7 @@ Abuse Prevention:
import argparse
import json
import os
import sys
import torch
import asyncio
import logging
@ -59,18 +60,29 @@ MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
# Command-line interface for the web server: each option below carries its own
# default, and the arguments are parsed right here as a plain statement.
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
# The default '' is not one of the listed choices; per the help text it means "autodetect".
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
# With no explicit list, argparse reads the process command line (sys.argv[1:]).
args = parser.parse_args()
def build_parser():
    """Create the argument parser for the NanoChat web server.

    Returns:
        An ``argparse.ArgumentParser`` with every server option registered;
        nothing is parsed here.
    """
    cli = argparse.ArgumentParser(description='NanoChat Web Server')
    # (flags, add_argument keyword options), registered in this order.
    option_table = (
        (('-n', '--num-gpus'), dict(type=int, default=1, help='Number of GPUs to use (default: 1)')),
        (('-i', '--source'), dict(type=str, default="sft", help="Source of the model: sft|rl")),
        (('-t', '--temperature'), dict(type=float, default=0.8, help='Default temperature for generation')),
        (('-k', '--top-k'), dict(type=int, default=50, help='Default top-k sampling parameter')),
        (('-m', '--max-tokens'), dict(type=int, default=512, help='Default max tokens for generation')),
        (('-g', '--model-tag'), dict(type=str, default=None, help='Model tag to load')),
        (('-s', '--step'), dict(type=int, default=None, help='Step to load')),
        (('-p', '--port'), dict(type=int, default=8000, help='Port to run the server on')),
        # The default '' is outside `choices`; per the help text it means "autodetect".
        (('--device-type',), dict(type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')),
        (('--host',), dict(type=str, default='0.0.0.0', help='Host to bind the server to')),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli
def parse_args(argv=None):
    """Parse the web server's command-line options.

    Args:
        argv: List of argument strings to parse. When ``None``, the process
            arguments ``sys.argv[1:]`` are used.

    Returns:
        The namespace produced by the parser from ``build_parser()``.
    """
    cli = build_parser()
    cli_args = sys.argv[1:] if argv is None else argv
    return cli.parse_args(cli_args)
# Parsed as a plain statement rather than inside a function; with no argv passed,
# parse_args() falls back to sys.argv[1:].
args = parse_args()
# Configure logging for conversation traffic
logging.basicConfig(
@ -186,7 +198,7 @@ def validate_chat_request(request: ChatRequest):
if message.role not in ["user", "assistant"]:
raise HTTPException(
status_code=400,
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
detail=f"Message {i} has invalid role. Must be 'user' or 'assistant'"
)
# Validate temperature

36
tests/test_regressions.py Normal file
View File

@ -0,0 +1,36 @@
import importlib
import sys
import pytest
import torch
from fastapi import HTTPException
import nanochat.engine as engine_module
from nanochat.engine import Engine
class DummyModel:
    """Minimal model stand-in that exposes only ``get_device``."""

    def get_device(self):
        """Return the CPU device."""
        cpu_device = torch.device("cpu")
        return cpu_device
def test_engine_kv_cache_uses_compute_dtype(monkeypatch):
    """The engine's KV cache dtype follows ``COMPUTE_DTYPE`` as seen by the engine module."""
    patched_dtype = torch.float16
    # Patch the name inside the engine module, which is where the method looks it up.
    monkeypatch.setattr(engine_module, "COMPUTE_DTYPE", patched_dtype)

    kv_dtype = Engine(DummyModel(), tokenizer=None)._get_kv_cache_dtype()

    assert kv_dtype == patched_dtype
def test_chat_web_rejects_system_role_with_consistent_error(monkeypatch):
    """A 'system' message is rejected with HTTP 400 and a detail naming only the accepted roles."""
    # Present a clean argv and drop any cached copy so the module is imported afresh.
    monkeypatch.setattr(sys, "argv", ["chat_web_test"])
    sys.modules.pop("scripts.chat_web", None)
    chat_web = importlib.import_module("scripts.chat_web")

    system_message = chat_web.ChatMessage(role="system", content="You are helpful.")
    request = chat_web.ChatRequest(messages=[system_message])

    with pytest.raises(HTTPException) as exc_info:
        chat_web.validate_chat_request(request)

    raised = exc_info.value
    assert raised.status_code == 400
    assert raised.detail == "Message 0 has invalid role. Must be 'user' or 'assistant'"