diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 4b67b62..e343fe1 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -38,7 +38,7 @@ import asyncio
 import logging
 import random
 from contextlib import asynccontextmanager
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
@@ -72,6 +72,7 @@ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run th
 parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
 parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
 parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
+parser.add_argument('--root-path', type=str, default='', help='ASGI root path for proxy/gateway configurations.')
 args = parser.parse_args()

 # Configure logging for conversation traffic
@@ -240,18 +241,24 @@ app.add_middleware(
 )

 @app.get("/")
-async def root():
-    """Serve the chat UI."""
+async def root(request: Request):
+    """Serve the chat UI with the ASGI root path injected as its API base URL."""
     ui_html_path = os.path.join("nanochat", "ui.html")
     with open(ui_html_path, "r", encoding="utf-8") as f:
         html_content = f.read()

-    # Replace the API_URL to use the same origin
-    html_content = html_content.replace(
-        "const API_URL = `http://${window.location.hostname}:8000`;",
-        "const API_URL = '';"
-    )
-    return HTMLResponse(content=html_content)
+    # Prefix supplied by the proxy/ASGI server, without trailing slashes.
+    # str.rstrip is a no-op when there is no trailing slash, so no guard is needed.
+    proxy_prefix = request.scope.get('root_path', '').rstrip('/')
+
+    # NOTE(review): this substitution only takes effect if ui.html contains
+    # exactly "const API_URL = '';" -- confirm against nanochat/ui.html.
+    html_content = html_content.replace(
+        "const API_URL = '';",
+        f"const API_URL = '{proxy_prefix}';"
+    )
+
+    return HTMLResponse(content=html_content)

 @app.get("/logo.svg")
 async def logo():
@@ -412,4 +419,9 @@ if __name__ == "__main__":
     import uvicorn
     print(f"Starting NanoChat Web Server")
     print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
-    uvicorn.run(app, host=args.host, port=args.port)
+    uvicorn.run(
+        app,
+        host=args.host,
+        port=args.port,
+        root_path=args.root_path,
+    )