fix: V-002 security vulnerability

Automated security fix generated by Orbis Security AI
This commit is contained in:
orbisai0security 2026-04-30 13:33:47 +00:00
parent 0aaca56805
commit efb9765d10

View File

@ -36,9 +36,9 @@ import os
import torch
import asyncio
import logging
import random
import secrets
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
@ -108,7 +108,8 @@ class WorkerPool:
"""Load model on each GPU."""
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
if self.num_gpus > 1:
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
if device_type != "cuda":
raise ValueError("Only CUDA supports multiple workers/GPUs. cpu|mps does not.")
for gpu_id in range(self.num_gpus):
@ -224,9 +225,12 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)
_cors_origins_env = os.environ.get("CHAT_ALLOWED_ORIGINS", "")
_cors_origins = [o.strip() for o in _cors_origins_env.split(",") if o.strip()]
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=_cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
@ -278,7 +282,7 @@ async def generate_stream(
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
seed=random.randint(0, 2**31 - 1)
seed=secrets.randbelow(2**31)
):
token = token_column[0]
@ -302,7 +306,14 @@ async def generate_stream(
yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions")
def verify_api_key(x_api_key: Optional[str] = Header(default=None)):
    """Verify the API key supplied in the X-Api-Key request header.

    Enforcement is opt-in: when the CHAT_API_KEY environment variable is
    unset (or empty), every request is allowed through. When it is set,
    the header value must match it exactly.

    Args:
        x_api_key: Value of the X-Api-Key header, or None if absent.

    Raises:
        HTTPException: 401 when a key is configured and the header is
            missing or does not match.
    """
    api_key = os.environ.get("CHAT_API_KEY")
    # secrets.compare_digest performs a constant-time comparison, so an
    # attacker cannot recover the key byte-by-byte via response-timing
    # differences (a plain `!=` short-circuits on the first mismatch).
    # The `x_api_key and ...` guard keeps the missing-header case a clean
    # 401 instead of a TypeError inside compare_digest.
    if api_key and not (x_api_key and secrets.compare_digest(x_api_key, api_key)):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
@app.post("/chat/completions", dependencies=[Depends(verify_api_key)])
async def chat_completions(request: ChatRequest):
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""