From efb9765d100a5c748e8b62efa9cf727a6f1ccd22 Mon Sep 17 00:00:00 2001 From: orbisai0security Date: Thu, 30 Apr 2026 13:33:47 +0000 Subject: [PATCH] fix: V-002 security vulnerability Automated security fix generated by Orbis Security AI --- scripts/chat_web.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/scripts/chat_web.py b/scripts/chat_web.py index ffaf7dab..1f821916 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -36,9 +36,9 @@ import os import torch import asyncio import logging -import random +import secrets from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Header, Depends from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse from pydantic import BaseModel @@ -108,7 +108,8 @@ class WorkerPool: """Load model on each GPU.""" print(f"Initializing worker pool with {self.num_gpus} GPUs...") if self.num_gpus > 1: - assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not." + if device_type != "cuda": + raise ValueError("Only CUDA supports multiple workers/GPUs. cpu|mps does not.") for gpu_id in range(self.num_gpus): @@ -224,9 +225,12 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) +_cors_origins_env = os.environ.get("CHAT_ALLOWED_ORIGINS", "") +_cors_origins = [o.strip() for o in _cors_origins_env.split(",") if o.strip()] + app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=_cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -278,7 +282,7 @@ async def generate_stream( max_tokens=max_new_tokens, temperature=temperature, top_k=top_k, - seed=random.randint(0, 2**31 - 1) + seed=secrets.randbelow(2**31) ): token = token_column[0] @@ -302,7 +306,14 @@ async def generate_stream( yield f"data: {json.dumps({'done': True})}\n\n" -@app.post("/chat/completions") +def verify_api_key(x_api_key: Optional[str] = Header(default=None)): + """Verify API key from X-Api-Key header if CHAT_API_KEY env var is set.""" + api_key = os.environ.get("CHAT_API_KEY") + if api_key and x_api_key != api_key: + raise HTTPException(status_code=401, detail="Invalid or missing API key") + + +@app.post("/chat/completions", dependencies=[Depends(verify_api_key)]) async def chat_completions(request: ChatRequest): """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""