Compare commits

...

4 Commits

Author SHA1 Message Date
Pengyu Wang
0de282ac4a
Merge 7950813a41 into bc1fca39f3 2025-11-15 22:31:03 +04:00
Andrej Karpathy
bc1fca39f3 mqa -> gqa to reduce confusion 2025-11-15 15:43:37 +00:00
svlandeg
7950813a41 make changes more minimal 2025-11-14 09:49:34 +01:00
wang pengyu
7afd2fd206 add root_path for chat web 2025-10-17 02:56:41 +00:00
2 changed files with 13 additions and 10 deletions

View File

@@ -8,7 +8,7 @@ Notable features:
- norm after token embedding - norm after token embedding
- no learnable params in rmsnorm - no learnable params in rmsnorm
- no bias in linear layers - no bias in linear layers
- Multi-Query Attention (MQA) support for more efficient inference - Group-Query Attention (GQA) support for more efficient inference
""" """
import math import math
@@ -29,7 +29,7 @@ class GPTConfig:
vocab_size: int = 50304 vocab_size: int = 50304
n_layer: int = 12 n_layer: int = 12
n_head: int = 6 # number of query heads n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (MQA) n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768 n_embd: int = 768

View File

@@ -38,7 +38,7 @@ import asyncio
import logging import logging
import random import random
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
@@ -72,6 +72,7 @@ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run th
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
parser.add_argument('--root-path', type=str, default='', help='ASGI root path for proxy/gateway configurations')
args = parser.parse_args() args = parser.parse_args()
# Configure logging for conversation traffic # Configure logging for conversation traffic
@@ -240,19 +241,21 @@ app.add_middleware(
) )
@app.get("/") @app.get("/")
async def root(): async def root(request: Request):
"""Serve the chat UI.""" """Serve the chat UI, dynamically injecting the proxy path."""
ui_html_path = os.path.join("nanochat", "ui.html") ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r", encoding="utf-8") as f: with open(ui_html_path, "r", encoding="utf-8") as f:
html_content = f.read() html_content = f.read()
# Replace the API_URL to use the same origin
# Get the prefix provided by the proxy/ASGI server.
proxy_prefix = request.scope.get('root_path', '').rstrip('/')
html_content = html_content.replace( html_content = html_content.replace(
"const API_URL = `http://${window.location.hostname}:8000`;", "const API_URL = '';",
"const API_URL = '';" f"const API_URL = '{proxy_prefix}';"
) )
return HTMLResponse(content=html_content) return HTMLResponse(content=html_content)
@app.get("/logo.svg") @app.get("/logo.svg")
async def logo(): async def logo():
"""Serve the NanoChat logo for favicon and header.""" """Serve the NanoChat logo for favicon and header."""
@@ -412,4 +415,4 @@ if __name__ == "__main__":
import uvicorn import uvicorn
print(f"Starting NanoChat Web Server") print(f"Starting NanoChat Web Server")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}") print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app, host=args.host, port=args.port, root_path=args.root_path)