add root_path for chat web

This commit is contained in:
wang pengyu 2025-10-17 02:56:41 +00:00
parent d6d86cbf4c
commit 7afd2fd206

View File

@ -38,7 +38,7 @@ import asyncio
import logging import logging
import random import random
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
@ -70,6 +70,7 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on') parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
parser.add_argument('--root-path', type=str, default='', help='ASGI root path for proxy/gateway configurations.')
args = parser.parse_args() args = parser.parse_args()
# Configure logging for conversation traffic # Configure logging for conversation traffic
@ -224,18 +225,27 @@ app.add_middleware(
) )
@app.get("/") @app.get("/")
async def root(): async def root(request: Request):
"""Serve the chat UI.""" """
Serve the chat UI, dynamically injecting the proxy path.
"""
ui_html_path = os.path.join("nanochat", "ui.html") ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r") as f: with open(ui_html_path, "r") as f:
html_content = f.read() html_content = f.read()
# Replace the API_URL to use the same origin
html_content = html_content.replace(
"const API_URL = `http://${window.location.hostname}:8000`;",
"const API_URL = '';"
)
return HTMLResponse(content=html_content)
# Get the prefix provided by the proxy/ASGI server.
proxy_prefix = request.scope.get('root_path', '')
# Strip trailing slash if present
if proxy_prefix.endswith('/'):
proxy_prefix = proxy_prefix.rstrip('/')
html_content = html_content.replace(
"const API_URL = '';",
f"const API_URL = '{proxy_prefix}';"
)
return HTMLResponse(content=html_content)
@app.get("/logo.svg") @app.get("/logo.svg")
async def logo(): async def logo():
@ -396,4 +406,8 @@ if __name__ == "__main__":
import uvicorn import uvicorn
print(f"Starting NanoChat Web Server") print(f"Starting NanoChat Web Server")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}") print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app,
host=args.host,
port=args.port,
root_path=args.root_path
)