Compare commits

...

5 Commits

Author SHA1 Message Date
Pengyu Wang
32f9a3236d
Merge 7950813a41 into 4a87a0d19f 2025-11-18 10:41:03 +09:00
Andrej
4a87a0d19f
Merge pull request #299 from samjabrahams/rotary_embedding_head_dim_comment_cleanup
Fix comment: rotary embeddings final dimension size
2025-11-17 13:29:21 -08:00
Sam Abrahams
11e68bf442 Fix comment: rotary embeddings final dimension size 2025-11-17 11:32:56 -05:00
svlandeg
7950813a41 make changes more minimal 2025-11-14 09:49:34 +01:00
wang pengyu
7afd2fd206 add root_path for chat web 2025-10-17 02:56:41 +00:00
2 changed files with 12 additions and 9 deletions

View File

@@ -244,7 +244,7 @@ class GPT(nn.Module):
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"

View File

@@ -38,7 +38,7 @@ import asyncio
import logging
import random
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
@@ -72,6 +72,7 @@ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run th
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
parser.add_argument('--root-path', type=str, default='', help='ASGI root path for proxy/gateway configurations')
args = parser.parse_args()
# Configure logging for conversation traffic
@@ -240,19 +241,21 @@ app.add_middleware(
)
@app.get("/")
async def root():
"""Serve the chat UI."""
async def root(request: Request):
"""Serve the chat UI, dynamically injecting the proxy path."""
ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r", encoding="utf-8") as f:
html_content = f.read()
# Replace the API_URL to use the same origin
# Get the prefix provided by the proxy/ASGI server.
proxy_prefix = request.scope.get('root_path', '').rstrip('/')
html_content = html_content.replace(
"const API_URL = `http://${window.location.hostname}:8000`;",
"const API_URL = '';"
"const API_URL = '';",
f"const API_URL = '{proxy_prefix}';"
)
return HTMLResponse(content=html_content)
@app.get("/logo.svg")
async def logo():
"""Serve the NanoChat logo for favicon and header."""
@@ -412,4 +415,4 @@ if __name__ == "__main__":
import uvicorn
print(f"Starting NanoChat Web Server")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port)
uvicorn.run(app, host=args.host, port=args.port, root_path=args.root_path)