mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Compare commits
5 Commits
74d7a76ee0
...
4a31d47687
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4a31d47687 | ||
|
|
4a87a0d19f | ||
|
|
11e68bf442 | ||
|
|
04862cbfea | ||
|
|
28e6b9b9c2 |
28
README.md
28
README.md
|
|
@ -32,6 +32,34 @@ python -m scripts.chat_web
|
||||||
|
|
||||||
And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, so for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/), etc. Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun is a 4e19 FLOPs capability model so it's a bit like talking to a kindergartener :).
|
And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, so for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/), etc. Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun is a 4e19 FLOPs capability model so it's a bit like talking to a kindergartener :).
|
||||||
|
|
||||||
|
### CPU Inference
|
||||||
|
|
||||||
|
If you want to run inference on CPU (e.g., on your laptop or a machine without GPU), use the CPU web server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m scripts.chat_web_cpu --model-dir /tmp/nanochat
|
||||||
|
```
|
||||||
|
|
||||||
|
This script automatically converts the model to float32 and runs inference on CPU. You can then access the web UI at `http://localhost:8000` or use it via the OpenAI-compatible API.
|
||||||
|
|
||||||
|
CPU web server (`chat_web_cpu.py`) is compatible with the OpenAI API specification. This means you can use any OpenAI SDK, tool, or framework with your NanoChat models:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
api_key="not_set"
|
||||||
|
base_url="http://localhost:8000/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="nanochat",
|
||||||
|
messages=[{"role": "user", "content": "Hello!"}]
|
||||||
|
)
|
||||||
|
print(response.choices[0].message.content)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
<img width="2672" height="1520" alt="image" src="https://github.com/user-attachments/assets/ed39ddf8-2370-437a-bedc-0f39781e76b5" />
|
<img width="2672" height="1520" alt="image" src="https://github.com/user-attachments/assets/ed39ddf8-2370-437a-bedc-0f39781e76b5" />
|
||||||
|
|
|
||||||
|
|
@ -244,7 +244,7 @@ class GPT(nn.Module):
|
||||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||||
B, T = idx.size()
|
B, T = idx.size()
|
||||||
|
|
||||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
|
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||||
|
|
|
||||||
764
scripts/chat_web_cpu.py
Normal file
764
scripts/chat_web_cpu.py
Normal file
|
|
@ -0,0 +1,764 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
CPU-compatible web chat server - serves both UI and API from a single FastAPI instance.
|
||||||
|
Run with: python chat_web_cpu.py --model-dir /path/to/model
|
||||||
|
Then open http://localhost:8000 in your browser.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import pickle
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import torch
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse, JSONResponse
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional, AsyncGenerator, Literal, Union, Dict, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Minimal GPT implementation (copied from generate_cpu.py)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GPTConfig:
|
||||||
|
sequence_len: int = 1024
|
||||||
|
vocab_size: int = 50304
|
||||||
|
n_layer: int = 12
|
||||||
|
n_head: int = 6
|
||||||
|
n_kv_head: int = 6
|
||||||
|
n_embd: int = 768
|
||||||
|
|
||||||
|
|
||||||
|
def norm(x):
|
||||||
|
return F.rms_norm(x, (x.size(-1),))
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(x, cos, sin):
|
||||||
|
assert x.ndim == 4
|
||||||
|
d = x.shape[3] // 2
|
||||||
|
x1, x2 = x[..., :d], x[..., d:]
|
||||||
|
y1 = x1 * cos + x2 * sin
|
||||||
|
y2 = x1 * (-sin) + x2 * cos
|
||||||
|
out = torch.cat([y1, y2], 3)
|
||||||
|
out = out.to(x.dtype)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(x, n_rep):
|
||||||
|
if n_rep == 1:
|
||||||
|
return x
|
||||||
|
bs, n_kv_heads, slen, head_dim = x.shape
|
||||||
|
return (
|
||||||
|
x[:, :, None, :, :]
|
||||||
|
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
|
||||||
|
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CausalSelfAttention(nn.Module):
|
||||||
|
def __init__(self, config, layer_idx):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.n_head = config.n_head
|
||||||
|
self.n_kv_head = config.n_kv_head
|
||||||
|
self.n_embd = config.n_embd
|
||||||
|
self.head_dim = self.n_embd // self.n_head
|
||||||
|
assert self.n_embd % self.n_head == 0
|
||||||
|
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
||||||
|
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||||
|
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||||
|
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||||
|
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x, cos_sin, kv_cache):
|
||||||
|
B, T, C = x.size()
|
||||||
|
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
||||||
|
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||||
|
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||||
|
cos, sin = cos_sin
|
||||||
|
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
||||||
|
q, k = norm(q), norm(k)
|
||||||
|
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
||||||
|
if kv_cache is not None:
|
||||||
|
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
|
||||||
|
Tq = q.size(2)
|
||||||
|
Tk = k.size(2)
|
||||||
|
nrep = self.n_head // self.n_kv_head
|
||||||
|
k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
|
||||||
|
if kv_cache is None or Tq == Tk:
|
||||||
|
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||||
|
elif Tq == 1:
|
||||||
|
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
|
||||||
|
else:
|
||||||
|
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
|
||||||
|
prefix_len = Tk - Tq
|
||||||
|
if prefix_len > 0:
|
||||||
|
attn_mask[:, :prefix_len] = True
|
||||||
|
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||||
|
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
|
y = y.transpose(1, 2).contiguous().view(B, T, -1)
|
||||||
|
y = self.c_proj(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||||
|
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.c_fc(x)
|
||||||
|
x = F.relu(x).square()
|
||||||
|
x = self.c_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self, config, layer_idx):
|
||||||
|
super().__init__()
|
||||||
|
self.attn = CausalSelfAttention(config, layer_idx)
|
||||||
|
self.mlp = MLP(config)
|
||||||
|
|
||||||
|
def forward(self, x, cos_sin, kv_cache):
|
||||||
|
x = x + self.attn(norm(x), cos_sin, kv_cache)
|
||||||
|
x = x + self.mlp(norm(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GPT(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.transformer = nn.ModuleDict({
|
||||||
|
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
||||||
|
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||||
|
})
|
||||||
|
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||||
|
self.rotary_seq_len = config.sequence_len * 10
|
||||||
|
head_dim = config.n_embd // config.n_head
|
||||||
|
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||||
|
self.register_buffer("cos", cos, persistent=False)
|
||||||
|
self.register_buffer("sin", sin, persistent=False)
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
torch.nn.init.zeros_(self.lm_head.weight)
|
||||||
|
for block in self.transformer.h:
|
||||||
|
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||||
|
torch.nn.init.zeros_(block.attn.c_proj.weight)
|
||||||
|
head_dim = self.config.n_embd // self.config.n_head
|
||||||
|
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||||
|
self.cos, self.sin = cos, sin
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
fan_out = module.weight.size(0)
|
||||||
|
fan_in = module.weight.size(1)
|
||||||
|
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
|
||||||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
|
||||||
|
if module.bias is not None:
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
|
||||||
|
|
||||||
|
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||||
|
if device is None:
|
||||||
|
device = self.transformer.wte.weight.device
|
||||||
|
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
|
||||||
|
inv_freq = 1.0 / (base ** (channel_range / head_dim))
|
||||||
|
t = torch.arange(seq_len, dtype=torch.float32, device=device)
|
||||||
|
freqs = torch.outer(t, inv_freq)
|
||||||
|
cos, sin = freqs.cos(), freqs.sin()
|
||||||
|
cos, sin = cos[None, :, None, :], sin[None, :, None, :]
|
||||||
|
return cos, sin
|
||||||
|
|
||||||
|
def forward(self, idx, targets=None, kv_cache=None):
|
||||||
|
B, T = idx.size()
|
||||||
|
assert T <= self.cos.size(1)
|
||||||
|
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||||
|
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
|
||||||
|
x = self.transformer.wte(idx)
|
||||||
|
x = norm(x)
|
||||||
|
for block in self.transformer.h:
|
||||||
|
x = block(x, cos_sin, kv_cache)
|
||||||
|
x = norm(x)
|
||||||
|
softcap = 15
|
||||||
|
logits = self.lm_head(x)
|
||||||
|
logits = softcap * torch.tanh(logits / softcap)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Simple tokenizer wrapper
|
||||||
|
|
||||||
|
class SimpleTokenizer:
|
||||||
|
def __init__(self, enc):
|
||||||
|
self.enc = enc
|
||||||
|
try:
|
||||||
|
self.bos_token_id = enc.encode_single_token("<|bos|>")
|
||||||
|
except:
|
||||||
|
try:
|
||||||
|
self.bos_token_id = enc.encode_single_token("<|endoftext|>")
|
||||||
|
except:
|
||||||
|
self.bos_token_id = 0
|
||||||
|
|
||||||
|
# Get special tokens
|
||||||
|
try:
|
||||||
|
self.user_start = enc.encode_single_token("<|user_start|>")
|
||||||
|
self.user_end = enc.encode_single_token("<|user_end|>")
|
||||||
|
self.assistant_start = enc.encode_single_token("<|assistant_start|>")
|
||||||
|
self.assistant_end = enc.encode_single_token("<|assistant_end|>")
|
||||||
|
except:
|
||||||
|
# Fallback if special tokens don't exist
|
||||||
|
self.user_start = 0
|
||||||
|
self.user_end = 0
|
||||||
|
self.assistant_start = 0
|
||||||
|
self.assistant_end = 0
|
||||||
|
|
||||||
|
def get_bos_token_id(self):
|
||||||
|
return self.bos_token_id
|
||||||
|
|
||||||
|
def encode_special(self, token):
|
||||||
|
try:
|
||||||
|
return self.enc.encode_single_token(token)
|
||||||
|
except:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
return self.enc.encode_ordinary(text)
|
||||||
|
|
||||||
|
def decode(self, tokens):
|
||||||
|
return self.enc.decode(tokens)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Simple generator (no Engine class needed)
|
||||||
|
|
||||||
|
def generate_tokens(model, input_tokens, max_tokens=512, temperature=0.8, top_k=50, device='cpu'):
|
||||||
|
"""Generate tokens one at a time."""
|
||||||
|
x = torch.tensor([input_tokens], dtype=torch.long, device=device)
|
||||||
|
generated = []
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
for _ in range(max_tokens):
|
||||||
|
logits = model(x)
|
||||||
|
logits = logits[:, -1, :] / temperature
|
||||||
|
|
||||||
|
if top_k > 0:
|
||||||
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||||
|
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||||
|
|
||||||
|
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
|
generated.append(next_token.item())
|
||||||
|
x = torch.cat([x, next_token], dim=1)
|
||||||
|
|
||||||
|
yield next_token.item()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# FastAPI app
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='NanoChat Web Server (CPU)')
|
||||||
|
parser.add_argument('--model-dir', type=str, required=True, help='Path to model directory containing model_*.pt, meta_*.json, and tokenizer.pkl')
|
||||||
|
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
|
||||||
|
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
|
||||||
|
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
|
||||||
|
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||||
|
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
# OpenAI-compatible request/response models
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: Literal["system", "user", "assistant"]
|
||||||
|
content: str # Only text content supported
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
class ChatCompletionRequest(BaseModel):
|
||||||
|
model: str = Field(default="nanochat", description="Model to use for completion")
|
||||||
|
messages: List[ChatMessage]
|
||||||
|
# Supported parameters
|
||||||
|
temperature: Optional[float] = Field(default=None, ge=0, le=2)
|
||||||
|
max_tokens: Optional[int] = Field(default=None, ge=1)
|
||||||
|
top_k: Optional[int] = Field(default=None, ge=1, description="Top-k sampling (NanoChat-specific)")
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
# Accepted but not supported (will be rejected if provided)
|
||||||
|
top_p: Optional[float] = Field(default=None, ge=0, le=1)
|
||||||
|
n: Optional[int] = Field(default=None, ge=1)
|
||||||
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
|
presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
|
||||||
|
frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None
|
||||||
|
user: Optional[str] = None
|
||||||
|
# Not supported features
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None
|
||||||
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
|
||||||
|
functions: Optional[List[Dict[str, Any]]] = None
|
||||||
|
function_call: Optional[Union[str, Dict[str, Any]]] = None
|
||||||
|
|
||||||
|
class ChatCompletionResponseChoice(BaseModel):
|
||||||
|
index: int
|
||||||
|
message: ChatMessage
|
||||||
|
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
|
||||||
|
|
||||||
|
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||||
|
index: int
|
||||||
|
delta: Dict[str, Any]
|
||||||
|
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
|
||||||
|
|
||||||
|
class UsageInfo(BaseModel):
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
total_tokens: int
|
||||||
|
|
||||||
|
class ChatCompletionResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: Literal["chat.completion"] = "chat.completion"
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
choices: List[ChatCompletionResponseChoice]
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
class ChatCompletionStreamResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
choices: List[ChatCompletionResponseStreamChoice]
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Load model on startup."""
|
||||||
|
print(f"Loading model from {args.model_dir}...")
|
||||||
|
|
||||||
|
# Find model and meta files
|
||||||
|
model_files = glob.glob(os.path.join(args.model_dir, "model_*.pt"))
|
||||||
|
if not model_files:
|
||||||
|
raise FileNotFoundError(f"No model files found in {args.model_dir}")
|
||||||
|
model_file = model_files[0]
|
||||||
|
|
||||||
|
meta_files = glob.glob(os.path.join(args.model_dir, "meta_*.json"))
|
||||||
|
if not meta_files:
|
||||||
|
raise FileNotFoundError(f"No meta files found in {args.model_dir}")
|
||||||
|
meta_file = meta_files[0]
|
||||||
|
|
||||||
|
# Load metadata
|
||||||
|
with open(meta_file, 'r') as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
model_config_kwargs = meta["model_config"]
|
||||||
|
print(f"Model config: {model_config_kwargs}")
|
||||||
|
|
||||||
|
# Build the model
|
||||||
|
model_config = GPTConfig(**model_config_kwargs)
|
||||||
|
with torch.device("meta"):
|
||||||
|
model = GPT(model_config)
|
||||||
|
|
||||||
|
# Load model weights
|
||||||
|
print("Loading model weights...")
|
||||||
|
model_data = torch.load(model_file, map_location=device, weights_only=False)
|
||||||
|
model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
|
||||||
|
|
||||||
|
# Convert bfloat16 to float32 for CPU
|
||||||
|
print("Converting model to float32 for CPU...")
|
||||||
|
model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()}
|
||||||
|
|
||||||
|
model.to_empty(device=device)
|
||||||
|
model.init_weights()
|
||||||
|
model.load_state_dict(model_data, strict=True, assign=True)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Load tokenizer
|
||||||
|
print("Loading tokenizer...")
|
||||||
|
tokenizer_path = os.path.join(args.model_dir, "tokenizer.pkl")
|
||||||
|
if not os.path.exists(tokenizer_path):
|
||||||
|
raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
|
||||||
|
|
||||||
|
with open(tokenizer_path, "rb") as f:
|
||||||
|
import tiktoken
|
||||||
|
enc = pickle.load(f)
|
||||||
|
|
||||||
|
tokenizer = SimpleTokenizer(enc)
|
||||||
|
|
||||||
|
app.state.model = model
|
||||||
|
app.state.tokenizer = tokenizer
|
||||||
|
|
||||||
|
print(f"✓ Model loaded successfully!")
|
||||||
|
print(f"✓ Server ready at http://localhost:{args.port}")
|
||||||
|
yield
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
# Custom exception handler for OpenAI-compatible error responses
|
||||||
|
class OpenAIError(Exception):
|
||||||
|
"""Custom exception that returns OpenAI-compatible error format."""
|
||||||
|
def __init__(self, message: str, error_type: str = "invalid_request_error", param: str = None, code: str = None):
|
||||||
|
self.message = message
|
||||||
|
self.error_type = error_type
|
||||||
|
self.param = param
|
||||||
|
self.code = code
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
@app.exception_handler(OpenAIError)
|
||||||
|
async def openai_error_handler(request: Request, exc: OpenAIError):
|
||||||
|
"""Return errors in OpenAI API format."""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"message": exc.message,
|
||||||
|
"type": exc.error_type,
|
||||||
|
"param": exc.param,
|
||||||
|
"code": exc.code
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
|
"""Handle Pydantic validation errors in OpenAI format."""
|
||||||
|
errors = exc.errors()
|
||||||
|
if errors:
|
||||||
|
first_error = errors[0]
|
||||||
|
param = ".".join(str(x) for x in first_error.get("loc", []))
|
||||||
|
message = first_error.get("msg", "Invalid request")
|
||||||
|
else:
|
||||||
|
param = None
|
||||||
|
message = "Invalid request"
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"message": message,
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": param,
|
||||||
|
"code": None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""Serve the chat UI."""
|
||||||
|
ui_html_path = os.path.join("nanochat", "ui.html")
|
||||||
|
with open(ui_html_path, "r") as f:
|
||||||
|
html_content = f.read()
|
||||||
|
# Replace the API_URL to use the same origin
|
||||||
|
html_content = html_content.replace(
|
||||||
|
"const API_URL = `http://${window.location.hostname}:8000`;",
|
||||||
|
"const API_URL = '';"
|
||||||
|
)
|
||||||
|
return HTMLResponse(content=html_content)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/logo.svg")
|
||||||
|
async def logo():
|
||||||
|
"""Serve the NanoChat logo for favicon and header."""
|
||||||
|
logo_path = os.path.join("nanochat", "logo.svg")
|
||||||
|
return FileResponse(logo_path, media_type="image/svg+xml")
|
||||||
|
|
||||||
|
async def generate_stream(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
tokens,
|
||||||
|
completion_id: str,
|
||||||
|
model_name: str,
|
||||||
|
created: int,
|
||||||
|
temperature=None,
|
||||||
|
max_new_tokens=None,
|
||||||
|
top_k=None
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Generate assistant response with OpenAI-compatible streaming.
|
||||||
|
|
||||||
|
Supported parameters: temperature, max_new_tokens, top_k
|
||||||
|
"""
|
||||||
|
temperature = temperature if temperature is not None else args.temperature
|
||||||
|
# Greedy decoding when temperature <= 0
|
||||||
|
if temperature is not None and temperature <= 0:
|
||||||
|
temperature = 1e-6
|
||||||
|
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
|
||||||
|
# Enforce max 1000 cap
|
||||||
|
if max_new_tokens is None:
|
||||||
|
max_new_tokens = 256
|
||||||
|
max_new_tokens = max(1, min(1000, int(max_new_tokens)))
|
||||||
|
top_k = top_k if top_k is not None else args.top_k
|
||||||
|
if top_k is None:
|
||||||
|
top_k = 50
|
||||||
|
vocab_size = getattr(app.state.model.config, 'vocab_size', 50257)
|
||||||
|
top_k = max(1, min(int(top_k), int(vocab_size)))
|
||||||
|
|
||||||
|
assistant_end = tokenizer.encode_special("<|assistant_end|>")
|
||||||
|
bos = tokenizer.get_bos_token_id()
|
||||||
|
|
||||||
|
# Send initial chunk with role
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=completion_id,
|
||||||
|
created=created,
|
||||||
|
model=model_name,
|
||||||
|
choices=[ChatCompletionResponseStreamChoice(
|
||||||
|
index=0,
|
||||||
|
delta={"role": "assistant", "content": ""},
|
||||||
|
finish_reason=None
|
||||||
|
)]
|
||||||
|
)
|
||||||
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
finish_reason = "length"
|
||||||
|
for token in generate_tokens(model, tokens, max_new_tokens, temperature, top_k, device):
|
||||||
|
if token == assistant_end or token == bos:
|
||||||
|
finish_reason = "stop"
|
||||||
|
break
|
||||||
|
|
||||||
|
token_text = tokenizer.decode([token])
|
||||||
|
|
||||||
|
# Send content chunk
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=completion_id,
|
||||||
|
created=created,
|
||||||
|
model=model_name,
|
||||||
|
choices=[ChatCompletionResponseStreamChoice(
|
||||||
|
index=0,
|
||||||
|
delta={"content": token_text},
|
||||||
|
finish_reason=None
|
||||||
|
)]
|
||||||
|
)
|
||||||
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
# Send final chunk with finish_reason
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=completion_id,
|
||||||
|
created=created,
|
||||||
|
model=model_name,
|
||||||
|
choices=[ChatCompletionResponseStreamChoice(
|
||||||
|
index=0,
|
||||||
|
delta={},
|
||||||
|
finish_reason=finish_reason
|
||||||
|
)]
|
||||||
|
)
|
||||||
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
# OpenAI sends [DONE] at the end
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
@app.post("/chat/completions")
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def chat_completions(request: ChatCompletionRequest):
|
||||||
|
"""
|
||||||
|
OpenAI-compatible chat completion endpoint.
|
||||||
|
|
||||||
|
Supported parameters:
|
||||||
|
- messages: Array of message objects (text only)
|
||||||
|
- temperature: Sampling temperature (0-2)
|
||||||
|
- max_tokens: Maximum tokens to generate
|
||||||
|
- top_k: Top-k sampling (NanoChat-specific)
|
||||||
|
- stream: Enable streaming responses
|
||||||
|
|
||||||
|
Not supported (rejected with clear errors):
|
||||||
|
- top_p, n, stop, presence_penalty, frequency_penalty, logit_bias, user
|
||||||
|
- tools, functions (function calling not supported)
|
||||||
|
- Multi-modal content (only text messages supported)
|
||||||
|
"""
|
||||||
|
model = app.state.model
|
||||||
|
tokenizer = app.state.tokenizer
|
||||||
|
|
||||||
|
# Validate unsupported features
|
||||||
|
if request.tools or request.tool_choice or request.functions or request.function_call:
|
||||||
|
raise OpenAIError(
|
||||||
|
message="Function calling and tools are not supported by this model. Only text completion is available.",
|
||||||
|
error_type="invalid_request_error",
|
||||||
|
code="unsupported_feature"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reject any unsupported standard params if provided
|
||||||
|
unsupported_fields = []
|
||||||
|
if request.n is not None:
|
||||||
|
unsupported_fields.append("n")
|
||||||
|
if request.top_p is not None:
|
||||||
|
unsupported_fields.append("top_p")
|
||||||
|
if request.stop is not None:
|
||||||
|
unsupported_fields.append("stop")
|
||||||
|
if request.presence_penalty is not None:
|
||||||
|
unsupported_fields.append("presence_penalty")
|
||||||
|
if request.frequency_penalty is not None:
|
||||||
|
unsupported_fields.append("frequency_penalty")
|
||||||
|
if request.logit_bias is not None:
|
||||||
|
unsupported_fields.append("logit_bias")
|
||||||
|
if request.user is not None:
|
||||||
|
unsupported_fields.append("user")
|
||||||
|
|
||||||
|
if unsupported_fields:
|
||||||
|
raise OpenAIError(
|
||||||
|
message=f"Unsupported parameters for this model: {', '.join(unsupported_fields)}. Supported only: messages, temperature, max_tokens, top_k, stream.",
|
||||||
|
error_type="invalid_request_error",
|
||||||
|
param=unsupported_fields[0],
|
||||||
|
code="unsupported_parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate messages are text-only
|
||||||
|
for i, msg in enumerate(request.messages):
|
||||||
|
if not isinstance(msg.content, str):
|
||||||
|
raise OpenAIError(
|
||||||
|
message=f"Message at index {i} contains non-text content. Only text messages are supported.",
|
||||||
|
error_type="invalid_request_error",
|
||||||
|
param=f"messages[{i}].content",
|
||||||
|
code="invalid_message_content"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate unique completion ID and timestamp
|
||||||
|
completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||||||
|
created = int(time.time())
|
||||||
|
model_name = request.model
|
||||||
|
|
||||||
|
# Build conversation tokens
|
||||||
|
bos = tokenizer.get_bos_token_id()
|
||||||
|
user_start = tokenizer.encode_special("<|user_start|>")
|
||||||
|
user_end = tokenizer.encode_special("<|user_end|>")
|
||||||
|
assistant_start = tokenizer.encode_special("<|assistant_start|>")
|
||||||
|
assistant_end = tokenizer.encode_special("<|assistant_end|>")
|
||||||
|
system_start = tokenizer.encode_special("<|system_start|>")
|
||||||
|
system_end = tokenizer.encode_special("<|system_end|>")
|
||||||
|
|
||||||
|
conversation_tokens = [bos]
|
||||||
|
for message in request.messages:
|
||||||
|
if message.role == "user":
|
||||||
|
conversation_tokens.append(user_start)
|
||||||
|
conversation_tokens.extend(tokenizer.encode(message.content))
|
||||||
|
conversation_tokens.append(user_end)
|
||||||
|
elif message.role == "assistant":
|
||||||
|
conversation_tokens.append(assistant_start)
|
||||||
|
conversation_tokens.extend(tokenizer.encode(message.content))
|
||||||
|
conversation_tokens.append(assistant_end)
|
||||||
|
elif message.role == "system":
|
||||||
|
# Handle system messages if supported
|
||||||
|
if system_start != 0 and system_end != 0:
|
||||||
|
conversation_tokens.append(system_start)
|
||||||
|
conversation_tokens.extend(tokenizer.encode(message.content))
|
||||||
|
conversation_tokens.append(system_end)
|
||||||
|
else:
|
||||||
|
# Fallback: treat system message as user message
|
||||||
|
conversation_tokens.append(user_start)
|
||||||
|
conversation_tokens.extend(tokenizer.encode(message.content))
|
||||||
|
conversation_tokens.append(user_end)
|
||||||
|
|
||||||
|
conversation_tokens.append(assistant_start)
|
||||||
|
prompt_tokens = len(conversation_tokens)
|
||||||
|
|
||||||
|
# Use only supported parameters: temperature, max_tokens, top_k
|
||||||
|
if request.stream:
|
||||||
|
return StreamingResponse(
|
||||||
|
generate_stream(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
conversation_tokens,
|
||||||
|
completion_id=completion_id,
|
||||||
|
model_name=model_name,
|
||||||
|
created=created,
|
||||||
|
temperature=request.temperature,
|
||||||
|
max_new_tokens=request.max_tokens,
|
||||||
|
top_k=request.top_k
|
||||||
|
),
|
||||||
|
media_type="text/event-stream"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Non-streaming response
|
||||||
|
temperature = request.temperature if request.temperature is not None else args.temperature
|
||||||
|
# Enforce max 1000 tokens cap
|
||||||
|
max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
|
||||||
|
if max_tokens is None:
|
||||||
|
max_tokens = 256
|
||||||
|
max_tokens = max(1, min(1000, int(max_tokens)))
|
||||||
|
# Validate top_k: 1..vocab_size
|
||||||
|
top_k = request.top_k if request.top_k is not None else args.top_k
|
||||||
|
if top_k is None:
|
||||||
|
top_k = 50
|
||||||
|
vocab_size = getattr(app.state.model.config, 'vocab_size', 50257)
|
||||||
|
top_k = max(1, min(int(top_k), int(vocab_size)))
|
||||||
|
|
||||||
|
generated_tokens = []
|
||||||
|
finish_reason = "length"
|
||||||
|
|
||||||
|
for token in generate_tokens(model, conversation_tokens, max_tokens, temperature, top_k, device):
|
||||||
|
if token == assistant_end or token == bos:
|
||||||
|
finish_reason = "stop"
|
||||||
|
break
|
||||||
|
generated_tokens.append(token)
|
||||||
|
|
||||||
|
response_text = tokenizer.decode(generated_tokens)
|
||||||
|
completion_tokens = len(generated_tokens)
|
||||||
|
|
||||||
|
return ChatCompletionResponse(
|
||||||
|
id=completion_id,
|
||||||
|
created=created,
|
||||||
|
model=model_name,
|
||||||
|
choices=[ChatCompletionResponseChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(role="assistant", content=response_text),
|
||||||
|
finish_reason=finish_reason
|
||||||
|
)],
|
||||||
|
usage=UsageInfo(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/v1/models")
|
||||||
|
@app.get("/models")
|
||||||
|
async def list_models():
|
||||||
|
"""
|
||||||
|
List available models (OpenAI-compatible endpoint).
|
||||||
|
|
||||||
|
Returns model information with capabilities annotation.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "nanochat",
|
||||||
|
"object": "model",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"owned_by": "nanochat",
|
||||||
|
"permission": [],
|
||||||
|
"root": "nanochat",
|
||||||
|
"parent": None
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
"""Health check endpoint."""
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"ready": hasattr(app.state, 'model') and app.state.model is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
print(f"Starting NanoChat Web Server (CPU mode)")
|
||||||
|
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
|
||||||
|
uvicorn.run(app, host=args.host, port=args.port)
|
||||||
|
|
||||||
Loading…
Reference in New Issue
Block a user