diff --git a/scripts/chat_web_cpu.py b/scripts/chat_web_cpu.py
new file mode 100644
index 0000000..0d7ed82
--- /dev/null
+++ b/scripts/chat_web_cpu.py
@@ -0,0 +1,479 @@
+#!/usr/bin/env python3
+"""
+CPU-compatible web chat server - serves both UI and API from a single FastAPI instance.
+Run from the repo root with: python scripts/chat_web_cpu.py --model-dir /path/to/model
+Then open http://localhost:8000 in your browser.
+"""
+
+import argparse
+import json
+import os
+import glob
+import pickle
+import math
+import torch
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
+from pydantic import BaseModel
+from typing import List, Optional, AsyncGenerator
+from dataclasses import dataclass
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+# -----------------------------------------------------------------------------
+# Minimal GPT implementation (copied from generate_cpu.py)
+
+@dataclass
+class GPTConfig:
+    sequence_len: int = 1024
+    vocab_size: int = 50304
+    n_layer: int = 12
+    n_head: int = 6
+    n_kv_head: int = 6
+    n_embd: int = 768
+
+
+def norm(x):
+    return F.rms_norm(x, (x.size(-1),))
+
+
+def apply_rotary_emb(x, cos, sin):
+    assert x.ndim == 4
+    d = x.shape[3] // 2
+    x1, x2 = x[..., :d], x[..., d:]
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    out = torch.cat([y1, y2], 3)
+    out = out.to(x.dtype)
+    return out
+
+
+def repeat_kv(x, n_rep):
+    if n_rep == 1:
+        return x
+    bs, n_kv_heads, slen, head_dim = x.shape
+    return (
+        x[:, :, None, :, :]
+        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+    )
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, config, layer_idx):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.n_head = config.n_head
+        self.n_kv_head = config.n_kv_head
+        self.n_embd = config.n_embd
+        self.head_dim = self.n_embd // self.n_head
+        assert self.n_embd % self.n_head == 0
+        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
+        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
+        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+
+    def forward(self, x, cos_sin, kv_cache):
+        B, T, C = x.size()
+        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+        cos, sin = cos_sin
+        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
+        q, k = norm(q), norm(k)
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+        if kv_cache is not None:
+            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
+        Tq = q.size(2)
+        Tk = k.size(2)
+        nrep = self.n_head // self.n_kv_head
+        k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
+        if kv_cache is None or Tq == Tk:
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+        elif Tq == 1:
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
+        else:
+            attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
+            prefix_len = Tk - Tq
+            if prefix_len > 0:
+                attn_mask[:, :prefix_len] = True
+            attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        y = y.transpose(1, 2).contiguous().view(B, T, -1)
+        y = self.c_proj(y)
+        return y
+
+
+class MLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+
+    def forward(self, x):
+        x = self.c_fc(x)
+        x = F.relu(x).square()
+        x = self.c_proj(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self, config, layer_idx):
+        super().__init__()
+        self.attn = CausalSelfAttention(config, layer_idx)
+        self.mlp = MLP(config)
+
+    def forward(self, x, cos_sin, kv_cache):
+        x = x + self.attn(norm(x), cos_sin, kv_cache)
+        x = x + self.mlp(norm(x))
+        return x
+
+
+class GPT(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.transformer = nn.ModuleDict({
+            "wte": nn.Embedding(config.vocab_size, config.n_embd),
+            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
+        })
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.rotary_seq_len = config.sequence_len * 10
+        head_dim = config.n_embd // config.n_head
+        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+        self.register_buffer("cos", cos, persistent=False)
+        self.register_buffer("sin", sin, persistent=False)
+
+    def init_weights(self):
+        self.apply(self._init_weights)
+        torch.nn.init.zeros_(self.lm_head.weight)
+        for block in self.transformer.h:
+            torch.nn.init.zeros_(block.mlp.c_proj.weight)
+            torch.nn.init.zeros_(block.attn.c_proj.weight)
+        head_dim = self.config.n_embd // self.config.n_head
+        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+        self.cos, self.sin = cos, sin
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            fan_out = module.weight.size(0)
+            fan_in = module.weight.size(1)
+            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
+
+    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+        if device is None:
+            device = self.transformer.wte.weight.device
+        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+        inv_freq = 1.0 / (base ** (channel_range / head_dim))
+        t = torch.arange(seq_len, dtype=torch.float32, device=device)
+        freqs = torch.outer(t, inv_freq)
+        cos, sin = freqs.cos(), freqs.sin()
+        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
+        return cos, sin
+
+    def forward(self, idx, targets=None, kv_cache=None):
+        B, T = idx.size()
+        assert T <= self.cos.size(1)
+        T0 = 0 if kv_cache is None else kv_cache.get_pos()
+        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
+        x = self.transformer.wte(idx)
+        x = norm(x)
+        for block in self.transformer.h:
+            x = block(x, cos_sin, kv_cache)
+        x = norm(x)
+        softcap = 15
+        logits = self.lm_head(x)
+        logits = softcap * torch.tanh(logits / softcap)
+        return logits
+
+
+# -----------------------------------------------------------------------------
+# Simple tokenizer wrapper
+
+class SimpleTokenizer:
+    def __init__(self, enc):
+        self.enc = enc
+        try:
+            self.bos_token_id = enc.encode_single_token("<|bos|>")
+        except:
+            try:
+                self.bos_token_id = enc.encode_single_token("<|endoftext|>")
+            except:
+                self.bos_token_id = 0
+
+        # Get special tokens
+        try:
+            self.user_start = enc.encode_single_token("<|user_start|>")
+            self.user_end = enc.encode_single_token("<|user_end|>")
+            self.assistant_start = enc.encode_single_token("<|assistant_start|>")
+            self.assistant_end = enc.encode_single_token("<|assistant_end|>")
+        except:
+            # Fallback if special tokens don't exist
+            self.user_start = 0
+            self.user_end = 0
+            self.assistant_start = 0
+            self.assistant_end = 0
+
+    def get_bos_token_id(self):
+        return self.bos_token_id
+
+    def encode_special(self, token):
+        try:
+            return self.enc.encode_single_token(token)
+        except:
+            return 0
+
+    def encode(self, text):
+        return self.enc.encode_ordinary(text)
+
+    def decode(self, tokens):
+        return self.enc.decode(tokens)
+
+
+# -----------------------------------------------------------------------------
+# Simple generator (no Engine class needed)
+
+def generate_tokens(model, input_tokens, max_tokens=512, temperature=0.8, top_k=50, device='cpu'):
+    """Generate tokens one at a time."""
+    x = torch.tensor([input_tokens], dtype=torch.long, device=device)
+    generated = []
+
+    with torch.inference_mode():
+        for _ in range(max_tokens):
+            logits = model(x)
+            logits = logits[:, -1, :] / temperature
+
+            if top_k > 0:
+                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                logits[logits < v[:, [-1]]] = -float('Inf')
+
+            probs = torch.nn.functional.softmax(logits, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
+
+            generated.append(next_token.item())
+            x = torch.cat([x, next_token], dim=1)
+
+            yield next_token.item()
+
+
+# -----------------------------------------------------------------------------
+# FastAPI app
+
+parser = argparse.ArgumentParser(description='NanoChat Web Server (CPU)')
+parser.add_argument('--model-dir', type=str, required=True, help='Path to model directory containing model_*.pt, meta_*.json, and tokenizer.pkl')
+parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
+parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
+parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
+parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
+parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
+args = parser.parse_args()
+
+device = torch.device("cpu")
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+class ChatRequest(BaseModel):
+    messages: List[ChatMessage]
+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = None
+    top_k: Optional[int] = None
+    stream: Optional[bool] = True
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Load model on startup."""
+    print(f"Loading model from {args.model_dir}...")
+
+    # Find model and meta files
+    model_files = glob.glob(os.path.join(args.model_dir, "model_*.pt"))
+    if not model_files:
+        raise FileNotFoundError(f"No model files found in {args.model_dir}")
+    model_file = model_files[0]
+
+    meta_files = glob.glob(os.path.join(args.model_dir, "meta_*.json"))
+    if not meta_files:
+        raise FileNotFoundError(f"No meta files found in {args.model_dir}")
+    meta_file = meta_files[0]
+
+    # Load metadata
+    with open(meta_file, 'r') as f:
+        meta = json.load(f)
+
+    model_config_kwargs = meta["model_config"]
+    print(f"Model config: {model_config_kwargs}")
+
+    # Build the model
+    model_config = GPTConfig(**model_config_kwargs)
+    with torch.device("meta"):
+        model = GPT(model_config)
+
+    # Load model weights
+    print("Loading model weights...")
+    model_data = torch.load(model_file, map_location=device, weights_only=False)
+    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}  # strip the torch.compile prefix
+
+    # Convert bfloat16 to float32 for CPU
+    print("Converting model to float32 for CPU...")
+    model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()}
+
+    model.to_empty(device=device)
+    model.init_weights()
+    model.load_state_dict(model_data, strict=True, assign=True)
+    model.eval()
+
+    # Load tokenizer
+    print("Loading tokenizer...")
+    tokenizer_path = os.path.join(args.model_dir, "tokenizer.pkl")
+    if not os.path.exists(tokenizer_path):
+        raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
+
+    with open(tokenizer_path, "rb") as f:
+        import tiktoken  # ensure tiktoken is importable before unpickling the encoder
+        enc = pickle.load(f)
+
+    tokenizer = SimpleTokenizer(enc)
+
+    app.state.model = model
+    app.state.tokenizer = tokenizer
+
+    print("✓ Model loaded successfully!")
+    print(f"✓ Server ready at http://localhost:{args.port}")
+    yield
+
+app = FastAPI(lifespan=lifespan)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+@app.get("/")
+async def root():
+    """Serve the chat UI."""
+    ui_html_path = os.path.join("nanochat", "ui.html")
+    with open(ui_html_path, "r") as f:
+        html_content = f.read()
+    # Replace the API_URL to use the same origin
+    html_content = html_content.replace(
+        "const API_URL = `http://${window.location.hostname}:8000`;",
+        "const API_URL = '';"
+    )
+    return HTMLResponse(content=html_content)
+
+
+@app.get("/logo.svg")
+async def logo():
+    """Serve the NanoChat logo for favicon and header."""
+    logo_path = os.path.join("nanochat", "logo.svg")
+    return FileResponse(logo_path, media_type="image/svg+xml")
+
+async def generate_stream(
+    model,
+    tokenizer,
+    tokens,
+    temperature=None,
+    max_new_tokens=None,
+    top_k=None
+) -> AsyncGenerator[str, None]:
+    """Generate assistant response with streaming."""
+    temperature = temperature if temperature is not None else args.temperature
+    max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
+    top_k = top_k if top_k is not None else args.top_k
+
+    assistant_end = tokenizer.encode_special("<|assistant_end|>")
+    bos = tokenizer.get_bos_token_id()
+
+    for token in generate_tokens(model, tokens, max_new_tokens, temperature, top_k, device):
+        if token == assistant_end or token == bos:
+            break
+
+        token_text = tokenizer.decode([token])
+        yield f"data: {json.dumps({'token': token_text})}\n\n"
+
+    yield f"data: {json.dumps({'done': True})}\n\n"
+
+@app.post("/chat/completions")
+async def chat_completions(request: ChatRequest):
+    """Chat completion endpoint with streaming."""
+    model = app.state.model
+    tokenizer = app.state.tokenizer
+
+    # Build conversation tokens
+    bos = tokenizer.get_bos_token_id()
+    user_start = tokenizer.encode_special("<|user_start|>")
+    user_end = tokenizer.encode_special("<|user_end|>")
+    assistant_start = tokenizer.encode_special("<|assistant_start|>")
+    assistant_end = tokenizer.encode_special("<|assistant_end|>")
+
+    conversation_tokens = [bos]
+    for message in request.messages:
+        if message.role == "user":
+            conversation_tokens.append(user_start)
+            conversation_tokens.extend(tokenizer.encode(message.content))
+            conversation_tokens.append(user_end)
+        elif message.role == "assistant":
+            conversation_tokens.append(assistant_start)
+            conversation_tokens.extend(tokenizer.encode(message.content))
+            conversation_tokens.append(assistant_end)
+
+    conversation_tokens.append(assistant_start)
+
+    if request.stream:
+        return StreamingResponse(
+            generate_stream(
+                model,
+                tokenizer,
+                conversation_tokens,
+                temperature=request.temperature,
+                max_new_tokens=request.max_tokens,
+                top_k=request.top_k
+            ),
+            media_type="text/event-stream"
+        )
+    else:
+        # Non-streaming response
+        temperature = request.temperature if request.temperature is not None else args.temperature
+        max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
+        top_k = request.top_k if request.top_k is not None else args.top_k
+
+        generated_tokens = list(generate_tokens(model, conversation_tokens, max_tokens, temperature, top_k, device))
+        response_text = tokenizer.decode(generated_tokens)
+
+        return {
+            "choices": [{
+                "message": {
+                    "role": "assistant",
+                    "content": response_text
+                },
+                "finish_reason": "stop"
+            }]
+        }
+
+@app.get("/health")
+async def health():
+    """Health check endpoint."""
+    return {
+        "status": "ok",
+        "ready": hasattr(app.state, 'model') and app.state.model is not None
+    }
+
+if __name__ == "__main__":
+    import uvicorn
+    print("Starting NanoChat Web Server (CPU mode)")
+    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
+    uvicorn.run(app, host=args.host, port=args.port)
+
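Example client for the streaming path (illustrative only, not part of the patch): a minimal sketch that assumes the server is running on the default port 8000 and uses the requests library, which this script does not itself depend on; the prompt text is made up. It reads the server-sent events emitted by generate_stream above, printing each token field and stopping at the final done event.

    import json
    import requests  # assumption: any HTTP client works; requests keeps the sketch short

    resp = requests.post(
        "http://localhost:8000/chat/completions",
        json={"messages": [{"role": "user", "content": "Hello!"}], "stream": True},
        stream=True,
    )
    for line in resp.iter_lines():
        if not line:
            continue  # skip the blank separator lines between SSE events
        event = json.loads(line.decode().removeprefix("data: "))
        if event.get("done"):
            break
        print(event["token"], end="", flush=True)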
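For the non-streaming path, a similarly hedged sketch: with "stream": false, chat_completions returns a single JSON body shaped like an OpenAI-style completion, so the reply text sits at choices[0].message.content.

    import requests  # assumption: same hypothetical client as above

    resp = requests.post(
        "http://localhost:8000/chat/completions",
        json={
            "messages": [{"role": "user", "content": "What is the capital of France?"}],
            "stream": False,
            "temperature": 0.8,
            "max_tokens": 128,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])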