Refactor chat_web startup for safe imports

This commit is contained in:
Sang Hun Kim 2026-04-11 15:24:21 +09:00
parent 63395bbade
commit 7b4549495c
2 changed files with 37 additions and 18 deletions

View File

@@ -60,6 +60,8 @@ MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
logger = logging.getLogger(__name__)
def build_parser():
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
@@ -82,18 +84,30 @@ def parse_args(argv=None):
return parser.parse_args(argv)
args = parse_args()
args = parse_args([])
device_type = None
ddp = ddp_rank = ddp_local_rank = ddp_world_size = None
device = None
_runtime_initialized = False
# Configure logging for conversation traffic
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
def configure_runtime(parsed_args=None):
    """Initialize module-level runtime state (args, logging, device, DDP info) exactly once.

    Making this an explicit, idempotent entry point lets the module be imported
    (e.g. by tests) without side effects such as device initialization; callers
    trigger setup only when they actually need it (see the FastAPI lifespan and
    the __main__ guard elsewhere in this file).

    Args:
        parsed_args: optional argparse.Namespace; when given, replaces the
            module-level `args` even if the runtime was already initialized.
    """
    global args, device_type, ddp, ddp_rank, ddp_local_rank, ddp_world_size, device, _runtime_initialized
    # Always honor an explicit override of `args`, even on repeat calls.
    if parsed_args is not None:
        args = parsed_args
    # Idempotence guard: everything below runs at most once per process.
    if _runtime_initialized:
        return
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    # Empty --device-type means "autodetect"; otherwise trust the user's choice.
    device_type = autodetect_device_type() if args.device_type == "" else args.device_type
    # NOTE(review): compute_init presumably sets up (distributed) compute state
    # for device_type — confirm against its definition elsewhere in the project.
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    _runtime_initialized = True
@dataclass
class Worker:
@@ -106,9 +120,10 @@ class Worker:
class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU."""
def __init__(self, num_gpus: Optional[int] = None):
def __init__(self, runtime_device_type: str, num_gpus: Optional[int] = None):
self.device_type = runtime_device_type
if num_gpus is None:
if device_type == "cuda":
if self.device_type == "cuda":
num_gpus = torch.cuda.device_count()
else:
num_gpus = 1 # e.g. cpu|mps
@@ -120,16 +135,16 @@ class WorkerPool:
"""Load model on each GPU."""
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
if self.num_gpus > 1:
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
assert self.device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
for gpu_id in range(self.num_gpus):
if device_type == "cuda":
if self.device_type == "cuda":
device = torch.device(f"cuda:{gpu_id}")
print(f"Loading model on GPU {gpu_id}...")
else:
device = torch.device(device_type) # e.g. cpu|mps
print(f"Loading model on {device_type}...")
device = torch.device(self.device_type) # e.g. cpu|mps
print(f"Loading model on {self.device_type}...")
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer)
@@ -228,8 +243,9 @@ def validate_chat_request(request: ChatRequest):
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load models on all GPUs on startup."""
configure_runtime()
print("Loading nanochat models across GPUs...")
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
app.state.worker_pool = WorkerPool(device_type, num_gpus=args.num_gpus)
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
print(f"Server ready at http://localhost:{args.port}")
yield
@@ -414,6 +430,7 @@ async def stats():
if __name__ == "__main__":
import uvicorn
configure_runtime(parse_args())
print(f"Starting NanoChat Web Server")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port)

View File

@@ -21,7 +21,7 @@ def test_engine_kv_cache_uses_compute_dtype(monkeypatch):
def test_chat_web_rejects_system_role_with_consistent_error(monkeypatch):
monkeypatch.setattr(sys, "argv", ["chat_web_test"])
monkeypatch.setattr(sys, "argv", ["chat_web_test", "--definitely-not-a-valid-flag"])
sys.modules.pop("scripts.chat_web", None)
chat_web = importlib.import_module("scripts.chat_web")
@@ -29,6 +29,8 @@ def test_chat_web_rejects_system_role_with_consistent_error(monkeypatch):
messages=[chat_web.ChatMessage(role="system", content="You are helpful.")]
)
assert chat_web._runtime_initialized is False
with pytest.raises(HTTPException) as exc_info:
chat_web.validate_chat_request(request)