Mirror of https://github.com/karpathy/nanochat.git (synced 2026-05-17 13:17:35 +00:00)
* feat(inference): deploy d24 SFT weights to Modal

  Repoint Modal inference app from the broken d20 checkpoint to our own
  ManmohanSharma/nanochat-d24 SFT step 484. Rewrites the standalone model as an
  inference-only port of nanochat/gpt.py so the modern architecture (smear gate,
  per-layer value embeddings, ve_gate, backout, sliding window attention via SDPA,
  rotary base 100000, padded vocab, logit softcap) loads cleanly from the
  checkpoint. Tokenizer loads the pickled tiktoken encoding directly so special
  tokens end up at their true IDs (32759-32767), and the stop check uses that set
  instead of hardcoded 0-8. GPU bumped to L4 for headroom. HF token sourced from
  the 'huggingface' Modal secret.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat(frontend): polished redesign with serif display + dark mode

  Lifts the craft level of the landing and chat UI without changing the desi
  identity. Adds Fraunces for display headlines, a floating pill LandingNav, a
  saffron-glow hero with a large serif headline and black pill CTAs, and three
  gradient-tiled feature cards with inline SVG glyphs replacing the emoji cards.
  The chat empty state is now a serif greeting with pill-chip prompt starters,
  and ChatInput is a single rounded pod so the send button sits inside the input
  (fixes the misaligned floating button). Adds a class-based dark mode across the
  chat surfaces with a sun/moon toggle in the sidebar footer, powered by a small
  useTheme hook and a no-flash init script in the root layout.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore(frontend): add ESLint config so CI lint step passes

  next lint was failing with an interactive prompt because the repo had no ESLint
  config. Adds a minimal next/core-web-vitals extends and drops the now-unloadable
  @typescript-eslint/no-explicit-any disable directive in the stream proxy by
  narrowing the body type to unknown.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
301 lines · 11 KiB · Python
"""
|
|
samosaChaat — Modal GPU inference endpoint.
|
|
|
|
Downloads nanochat model weights from HuggingFace into a Modal Volume,
|
|
loads them on a GPU, and exposes an SSE streaming endpoint compatible
|
|
with the samosaChaat chat-api service.
|
|
|
|
Deploy: modal deploy modal/serve.py
|
|
Dev: modal serve modal/serve.py
|
|
"""
from __future__ import annotations

import json
import os
import time

import modal

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_REPO = "ManmohanSharma/nanochat-d24"
MODEL_PT = "chatsft_checkpoints/d24/model_000484.pt"
META_JSON = "chatsft_checkpoints/d24/meta_000484.json"
TOKENIZER_PKL = "tokenizer/tokenizer.pkl"
TOKEN_BYTES = "tokenizer/token_bytes.pt"
MODEL_TAG = "d24-sft"
GPU_TYPE = "L4"  # 24 GB VRAM — fits 4 GB bf16 model loaded as fp32
VOLUME_NAME = "samosachaat-weights"
HF_SECRET_NAME = "huggingface"  # Modal secret containing HF_TOKEN

# ---------------------------------------------------------------------------
# Modal app + image
# ---------------------------------------------------------------------------
app = modal.App("samosachaat-inference")

# Build the container image with all dependencies
inference_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "torch==2.5.1",
        "tiktoken>=0.11.0",
        "tokenizers>=0.22.0",
        "huggingface_hub>=0.25.0",
        "fastapi>=0.115.0",
        "uvicorn>=0.30.0",
        extra_index_url="https://download.pytorch.org/whl/cu124",
    )
    .add_local_file("modal/_model.py", "/root/_model.py")
    .add_local_file("modal/_tokenizer.py", "/root/_tokenizer.py")
)

# Persistent volume for model weights
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)

# ---------------------------------------------------------------------------
# Download weights into the volume (runs once)
# ---------------------------------------------------------------------------
@app.function(
    image=inference_image,
    volumes={"/weights": volume},
    secrets=[modal.Secret.from_name(HF_SECRET_NAME)],
    timeout=1800,
)
def download_weights():
    """Download model weights from HuggingFace into the Modal volume."""
    import shutil
    from huggingface_hub import hf_hub_download

    model_dir = f"/weights/{MODEL_TAG}"
    os.makedirs(model_dir, exist_ok=True)

    token = os.environ.get("HF_TOKEN")

    # (HF source path, local filename in volume)
    files = [
        (MODEL_PT, "model.pt"),
        (META_JSON, "meta.json"),
        (TOKENIZER_PKL, "tokenizer.pkl"),
        (TOKEN_BYTES, "token_bytes.pt"),
    ]

    for src, local_name in files:
        dest = os.path.join(model_dir, local_name)
        if os.path.exists(dest):
            print(f" Already exists: {dest}")
            continue
        print(f" Downloading {src} from {MODEL_REPO}...")
        path = hf_hub_download(MODEL_REPO, src, token=token)
        shutil.copy2(path, dest)
        print(f" Saved to {dest}")

    volume.commit()
    print("Weights downloaded and committed to volume.")


# ---------------------------------------------------------------------------
# Inference class — GPU singleton
# ---------------------------------------------------------------------------
@app.cls(
    image=inference_image,
    volumes={"/weights": volume},
    gpu=GPU_TYPE,
    scaledown_window=300,  # keep warm for 5 min after last request
    # concurrency handled by @modal.concurrent below
    timeout=120,
)
class Inference:
    model: object
    tokenizer: object
    engine: object
    device: object

    @modal.enter()
    def load_model(self):
        """Called once when the container starts — loads model onto GPU."""
        import torch

        # _model.py and _tokenizer.py are shipped into /root via add_local_file,
        # so we import the minimal inference-only port directly instead of the
        # full nanochat package (which has heavy deps we don't need on Modal).
        print("Loading model...")
        t0 = time.time()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        model_dir = f"/weights/{MODEL_TAG}"
        meta_path = os.path.join(model_dir, "meta.json")
        model_path = os.path.join(model_dir, "model.pt")

        # Load meta
        with open(meta_path) as f:
            meta = json.load(f)
        model_config = meta if "model_config" not in meta else meta["model_config"]
        # Map HF config keys → nanochat GPTConfig keys
        seq_len = model_config.pop("n_positions", None) or model_config.pop("n_ctx", None)
        if seq_len and "sequence_len" not in model_config:
            model_config["sequence_len"] = seq_len
        # Drop whichever of n_ctx / n_positions the short-circuit above left behind
        model_config.pop("n_ctx", None)
        model_config.pop("n_positions", None)
        # Remove HF-specific keys that GPTConfig doesn't accept
        for k in ["architectures", "model_type", "rotary", "rotary_base", "tie_word_embeddings"]:
            model_config.pop(k, None)

        # Patch missing config keys
        model_config.setdefault("window_pattern", "L")

        print(f" Config: {model_config}")

        # Build the model from the inference-only GPT port shipped alongside
        # this file (_model.py), which mirrors nanochat/gpt.py.
        from _model import GPT, GPTConfig

        config = GPTConfig(**model_config)
        model_data = torch.load(model_path, map_location=device, weights_only=False)

        # Strip torch.compile prefix
        model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}

        # Convert bfloat16 weights to float32 for compatibility on non-Hopper GPUs
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }

        model = GPT.from_state_dict(config, model_data)
        model.load_state_dict(model_data, strict=True, assign=True)
        model.to(device)
        model.init_rotary(device=device, dtype=torch.float32)
        model.eval()

        self.model = model
        self.config = config

        # Load tokenizer
        from _tokenizer import get_tokenizer, SPECIAL_TOKENS

        self.tokenizer = get_tokenizer(model_dir)

        # Resolve actual special-token IDs (nanochat appends specials at end of vocab)
        self.special_token_ids = set()
        for name in SPECIAL_TOKENS:
            ids = self.tokenizer.encode_special(name)
            self.special_token_ids.update(ids)
        self.assistant_end_id = self.tokenizer.encode_special("<|assistant_end|>")[0]
        print(f" Special token IDs: {sorted(self.special_token_ids)}")

        dt = time.time() - t0
        print(f"Model loaded in {dt:.1f}s on {device}")

    @modal.fastapi_endpoint(method="POST", docs=True)
    async def generate(self, request: dict):
        """
        Streaming chat endpoint — SSE compatible with samosaChaat format.

        Input:  {"messages": [{"role": "user", "content": "..."}],
                 "temperature": 0.8, "max_tokens": 512, "top_k": 50}
        Output: SSE stream of data: {"token": "...", "gpu": 0} then data: {"done": true}
        """
        import torch
        from fastapi.responses import StreamingResponse

        messages = request.get("messages", [])
        temperature = min(max(request.get("temperature", 0.8), 0.0), 2.0)
        max_tokens = min(max(request.get("max_tokens", 512), 1), 2048)
        top_k = min(max(request.get("top_k", 50), 0), 200)

        # Build token sequence from messages
        tokens = []
        bos = self.tokenizer.encode_special("<|bos|>")
        user_start = self.tokenizer.encode_special("<|user_start|>")
        user_end = self.tokenizer.encode_special("<|user_end|>")
        assistant_start = self.tokenizer.encode_special("<|assistant_start|>")
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")

        tokens.extend(bos)
        for msg in messages:
            if msg["role"] == "user":
                tokens.extend(user_start)
                tokens.extend(self.tokenizer.encode(msg["content"]))
                tokens.extend(user_end)
            elif msg["role"] == "assistant":
                tokens.extend(assistant_start)
                tokens.extend(self.tokenizer.encode(msg["content"]))
                tokens.extend(assistant_end)
        # Prompt the model to generate an assistant response
        tokens.extend(assistant_start)
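        # For a single user turn the rendered prompt is therefore:
        #   <|bos|><|user_start|>Hello!<|user_end|><|assistant_start|>
        # and generation continues until a special token such as
        # <|assistant_end|> is sampled (see the stop check in the loop below).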

        # Truncate to fit context
        max_context = self.config.sequence_len - max_tokens
        if len(tokens) > max_context:
            tokens = tokens[-max_context:]

        async def stream():
            input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)

            with torch.no_grad():
                for _ in range(max_tokens):
                    # Forward pass
                    logits = self.model(input_ids)
                    next_logits = logits[:, -1, :]

                    if temperature > 0:
                        # Temperature
                        next_logits = next_logits / temperature

                        # Top-k filtering
                        if top_k > 0:
                            v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
                            next_logits[next_logits < v[:, [-1]]] = float('-inf')

                        # Sample
                        probs = torch.softmax(next_logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        # temperature == 0 → greedy decoding
                        next_token = torch.argmax(next_logits, dim=-1, keepdim=True)

                    token_id = next_token.item()

                    # Stop on any special token (assistant_end, bos, etc.)
                    if token_id in self.special_token_ids:
                        break

                    # Append for next iteration (before decoding, so a token that
                    # fails to decode still stays in the context)
                    input_ids = torch.cat([input_ids, next_token], dim=1)

                    # Truncate if exceeding sequence length
                    if input_ids.size(1) > self.config.sequence_len:
                        input_ids = input_ids[:, -self.config.sequence_len:]

                    # Decode and yield (skip tokens that can't be decoded)
                    try:
                        token_text = self.tokenizer.decode([token_id])
                    except Exception:
                        continue
                    yield f"data: {json.dumps({'token': token_text, 'gpu': 0})}\n\n"

            yield f"data: {json.dumps({'done': True})}\n\n"

        return StreamingResponse(
            stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    @modal.fastapi_endpoint(method="GET", docs=True)
    def health(self):
        return {
            "status": "ok",
            "model": MODEL_TAG,
            "gpu": GPU_TYPE,
            "ready": self.model is not None,
        }
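
# After `modal deploy modal/serve.py`, Modal prints one URL per web endpoint
# (generate and health). A quick liveness check against the health endpoint,
# using a placeholder URL (the real one comes from the deploy output):
#
#   curl https://<workspace>--samosachaat-inference-inference-health.modal.run
#   # → {"status": "ok", "model": "d24-sft", "gpu": "L4", "ready": true}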