diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 4b67b62..7a2359c 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -38,7 +38,7 @@ import asyncio import logging import random from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse from pydantic import BaseModel @@ -72,6 +72,7 @@ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run th parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to') +parser.add_argument('--root-path', type=str, default='', help='ASGI root path for proxy/gateway configurations') args = parser.parse_args() # Configure logging for conversation traffic @@ -240,19 +241,21 @@ app.add_middleware( ) @app.get("/") -async def root(): - """Serve the chat UI.""" +async def root(request: Request): + """Serve the chat UI, dynamically injecting the proxy path.""" ui_html_path = os.path.join("nanochat", "ui.html") with open(ui_html_path, "r", encoding="utf-8") as f: html_content = f.read() - # Replace the API_URL to use the same origin + + # Get the prefix provided by the proxy/ASGI server. + proxy_prefix = request.scope.get('root_path', '').rstrip('/') + html_content = html_content.replace( - "const API_URL = `http://${window.location.hostname}:8000`;", - "const API_URL = '';" + "const API_URL = '';", + f"const API_URL = '{proxy_prefix}';" ) return HTMLResponse(content=html_content) - @app.get("/logo.svg") async def logo(): """Serve the NanoChat logo for favicon and header.""" @@ -412,4 +415,4 @@ if __name__ == "__main__": import uvicorn print(f"Starting NanoChat Web Server") print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}") - uvicorn.run(app, host=args.host, port=args.port) + uvicorn.run(app, host=args.host, port=args.port, root_path=args.root_path)