mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Compare commits
4 Commits
03d69d58c1
...
0de282ac4a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0de282ac4a | ||
|
|
bc1fca39f3 | ||
|
|
7950813a41 | ||
|
|
7afd2fd206 |
|
|
@ -8,7 +8,7 @@ Notable features:
|
||||||
- norm after token embedding
|
- norm after token embedding
|
||||||
- no learnable params in rmsnorm
|
- no learnable params in rmsnorm
|
||||||
- no bias in linear layers
|
- no bias in linear layers
|
||||||
- Multi-Query Attention (MQA) support for more efficient inference
|
- Group-Query Attention (GQA) support for more efficient inference
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
@ -29,7 +29,7 @@ class GPTConfig:
|
||||||
vocab_size: int = 50304
|
vocab_size: int = 50304
|
||||||
n_layer: int = 12
|
n_layer: int = 12
|
||||||
n_head: int = 6 # number of query heads
|
n_head: int = 6 # number of query heads
|
||||||
n_kv_head: int = 6 # number of key/value heads (MQA)
|
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||||
n_embd: int = 768
|
n_embd: int = 768
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -72,6 +72,7 @@ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run th
|
||||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||||
|
parser.add_argument('--root-path', type=str, default='', help='ASGI root path for proxy/gateway configurations')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Configure logging for conversation traffic
|
# Configure logging for conversation traffic
|
||||||
|
|
@ -240,19 +241,21 @@ app.add_middleware(
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root(request: Request):
|
||||||
"""Serve the chat UI."""
|
"""Serve the chat UI, dynamically injecting the proxy path."""
|
||||||
ui_html_path = os.path.join("nanochat", "ui.html")
|
ui_html_path = os.path.join("nanochat", "ui.html")
|
||||||
with open(ui_html_path, "r", encoding="utf-8") as f:
|
with open(ui_html_path, "r", encoding="utf-8") as f:
|
||||||
html_content = f.read()
|
html_content = f.read()
|
||||||
# Replace the API_URL to use the same origin
|
|
||||||
|
# Get the prefix provided by the proxy/ASGI server.
|
||||||
|
proxy_prefix = request.scope.get('root_path', '').rstrip('/')
|
||||||
|
|
||||||
html_content = html_content.replace(
|
html_content = html_content.replace(
|
||||||
"const API_URL = `http://${window.location.hostname}:8000`;",
|
"const API_URL = '';",
|
||||||
"const API_URL = '';"
|
f"const API_URL = '{proxy_prefix}';"
|
||||||
)
|
)
|
||||||
return HTMLResponse(content=html_content)
|
return HTMLResponse(content=html_content)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/logo.svg")
|
@app.get("/logo.svg")
|
||||||
async def logo():
|
async def logo():
|
||||||
"""Serve the NanoChat logo for favicon and header."""
|
"""Serve the NanoChat logo for favicon and header."""
|
||||||
|
|
@ -412,4 +415,4 @@ if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
print(f"Starting NanoChat Web Server")
|
print(f"Starting NanoChat Web Server")
|
||||||
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
|
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
|
||||||
uvicorn.run(app, host=args.host, port=args.port)
|
uvicorn.run(app, host=args.host, port=args.port, root_path=args.root_path)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user