diff --git a/modal/_model.py b/modal/_model.py
new file mode 100644
index 00000000..a6a1da47
--- /dev/null
+++ b/modal/_model.py
@@ -0,0 +1,202 @@
+"""
+Minimal standalone GPT model for Modal inference.
+Extracted from nanochat/gpt.py — only the forward-pass code needed for inference.
+No training, no DDP, no flash_attention dependency.
+"""
+
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+@dataclass
+class GPTConfig:
+    sequence_len: int = 2048
+    vocab_size: int = 65536
+    n_layer: int = 20
+    n_head: int = 10
+    n_kv_head: int = 10
+    n_embd: int = 1280
+    window_pattern: str = "L"
+
+
+class RMSNorm(nn.Module):
+    """Parameter-free RMS normalization (no learned scale, nanochat-style)."""
+
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
+
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_seq_len=2048):
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.max_seq_len = max_seq_len
+
+    def forward(self, x, offset=0):
+        # x: (B, n_head, T, head_dim); only the sequence length T is read from it.
+        seq_len = x.shape[-2]
+        t = torch.arange(offset, offset + seq_len, device=x.device, dtype=self.inv_freq.dtype)
+        freqs = torch.outer(t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        return emb.cos(), emb.sin()
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    def rotate_half(x):
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, config: GPTConfig, use_v_emb: bool = False):
+        super().__init__()
+        self.n_head = config.n_head
+        self.n_kv_head = config.n_kv_head
+        self.head_dim = config.n_embd // config.n_head
+        self.n_embd = config.n_embd
+        self.use_v_emb = use_v_emb
+
+        self.c_q = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
+        self.c_k = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False)
+        self.c_v = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False)
+        self.c_proj = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False)
+
+        self.q_norm = RMSNorm(self.head_dim)
+        self.k_norm = RMSNorm(self.head_dim)
+
+        if use_v_emb:
+            # Learned per-position value embeddings, added to V (present in some checkpoints).
+            self.v_emb = nn.Parameter(torch.zeros(1, config.n_kv_head, config.sequence_len, self.head_dim))
+
+        self.rotary = RotaryEmbedding(self.head_dim, config.sequence_len)
+
+    def forward(self, x):
+        B, T, C = x.size()
+
+        q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
+
+        # QK norm
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        # Rotary embeddings
+        cos, sin = self.rotary(q)
+        q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+        # GQA: repeat k,v if n_kv_head < n_head
+        if self.n_kv_head < self.n_head:
+            rep = self.n_head // self.n_kv_head
+            k = k.repeat_interleave(rep, dim=1)
+            v = v.repeat_interleave(rep, dim=1)
+
+        # Value embeddings (if enabled)
+        if self.use_v_emb:
+            v = v + self.v_emb[:, :, :T, :]
+
+        # Scaled dot-product attention (PyTorch native, causal)
+        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+
+        y = y.transpose(1, 2).contiguous().view(B, T, C)
+        return self.c_proj(y)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: GPTConfig, gated: bool = False):
+        super().__init__()
+        self.gated = gated
+        if gated:
+            # SwiGLU-style sizing: ~8/3 * n_embd, rounded up to a multiple of 64.
+            hidden = int(config.n_embd * 8 / 3)
+            hidden = ((hidden + 63) // 64) * 64
+            self.c_fc = nn.Linear(config.n_embd, hidden, bias=False)
+            self.c_fc2 = nn.Linear(config.n_embd, hidden, bias=False)
+            self.c_proj = nn.Linear(hidden, config.n_embd, bias=False)
+        else:
+            hidden = 4 * config.n_embd
+            self.c_fc = nn.Linear(config.n_embd, hidden, bias=False)
+            self.c_proj = nn.Linear(hidden, config.n_embd, bias=False)
+
+    def forward(self, x):
+        # Both variants use the relu^2 activation; the gated variant multiplies
+        # by a second, unactivated projection.
+        if self.gated:
+            a = self.c_fc(x)
+            b = self.c_fc2(x)
+            return self.c_proj(F.relu(a).pow(2) * b)
+        else:
+            return self.c_proj(F.relu(self.c_fc(x)).pow(2))
+
+
+class Block(nn.Module):
+    def __init__(self, config: GPTConfig, layer_idx: int, gated_mlp: bool = False, use_v_emb: bool = False):
+        super().__init__()
+        self.ln_1 = RMSNorm(config.n_embd)
+        self.attn = CausalSelfAttention(config, use_v_emb=use_v_emb)
+        self.ln_2 = RMSNorm(config.n_embd)
+        self.mlp = MLP(config, gated=gated_mlp)
+        self.layer_idx = layer_idx
+
+    def forward(self, x, resid_lambda=1.0, x0_lambda=0.0, x0=None):
+        # Scaled residual stream, plus an optional skip back to the embedding x0.
+        h = x * resid_lambda + self.attn(self.ln_1(x))
+        if x0 is not None and x0_lambda != 0.0:
+            h = h + x0_lambda * x0
+        h2 = h * resid_lambda + self.mlp(self.ln_2(h))
+        if x0 is not None and x0_lambda != 0.0:
+            h2 = h2 + x0_lambda * x0
+        return h2
+
+
+class GPT(nn.Module):
+    def __init__(self, config: GPTConfig, gated_mlp: bool = False, use_v_emb: bool = False):
+        super().__init__()
+        self.config = config
+
+        self.transformer = nn.ModuleDict(dict(
+            wte=nn.Embedding(config.vocab_size, config.n_embd),
+            norm_emb=RMSNorm(config.n_embd),
+            h=nn.ModuleList([Block(config, i, gated_mlp=gated_mlp, use_v_emb=use_v_emb) for i in range(config.n_layer)]),
+            ln_f=RMSNorm(config.n_embd),
+        ))
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        # Residual lambdas (per-layer scaling)
+        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
+        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
+
+    @classmethod
+    def from_state_dict(cls, config: GPTConfig, state_dict: dict):
+        """Build a model whose architecture matches the checkpoint.
+
+        Detects optional features (gated MLP, value embeddings) from the key
+        names; the caller is responsible for actually loading the weights.
+        """
+        gated = any("c_fc2" in k for k in state_dict)
+        v_emb = any("v_emb" in k for k in state_dict)
+        model = cls(config, gated_mlp=gated, use_v_emb=v_emb)
+        return model
+
+    def init_weights(self):
+        """Re-initialize rotary inv_freq buffers (non-persistent, so absent from checkpoints)."""
+        for module in self.modules():
+            if isinstance(module, RotaryEmbedding):
+                dim = module.inv_freq.shape[0] * 2
+                inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+                module.inv_freq.copy_(inv_freq)
+
+    def forward(self, idx):
+        B, T = idx.size()
+        assert T <= self.config.sequence_len, f"Input length {T} exceeds max {self.config.sequence_len}"
+
+        x = self.transformer.wte(idx)
+        x = self.transformer.norm_emb(x)
+        x0 = x  # save for residual connections
+
+        for i, block in enumerate(self.transformer.h):
+            # Read the lambdas as Python scalars; fine for inference (no autograd).
+            rl = self.resid_lambdas[i].item()
+            xl = self.x0_lambdas[i].item()
+            x = block(x, resid_lambda=rl, x0_lambda=xl, x0=x0)
+
+        x = self.transformer.ln_f(x)
+        logits = self.lm_head(x)
+        return logits
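
For a quick standalone sanity check of this module, something along these lines should work. This is illustrative only: the checkpoint filename and config values are assumptions, and `serve.py` below does the real loading, including the `_orig_mod.` prefix strip and lambda patching.

```python
import torch

from _model import GPT, GPTConfig

# Illustrative: serve.py derives the config from the checkpoint's meta_*.json.
config = GPTConfig(n_layer=20, n_head=10, n_kv_head=10, n_embd=1280)

state = torch.load("model_021400.pt", map_location="cpu", weights_only=False)
state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}

model = GPT.from_state_dict(config, state)  # detects gated MLP / value embeddings
model.load_state_dict(state, strict=False)  # lambdas may be absent in some checkpoints
model.eval()

with torch.inference_mode():
    logits = model(torch.tensor([[0]]))  # a single <|bos|> token
print(logits.shape)  # torch.Size([1, 1, 65536])
```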
diff --git a/modal/_tokenizer.py b/modal/_tokenizer.py
new file mode 100644
index 00000000..59fff2ba
--- /dev/null
+++ b/modal/_tokenizer.py
@@ -0,0 +1,79 @@
+"""
+Minimal standalone tokenizer for Modal inference.
+Uses tiktoken for fast encoding/decoding with nanochat's special tokens.
+"""
+
+import os
+import pickle
+
+import tiktoken
+
+
+# nanochat special tokens
+SPECIAL_TOKENS = {
+    "<|bos|>": 0,
+    "<|user_start|>": 1,
+    "<|user_end|>": 2,
+    "<|assistant_start|>": 3,
+    "<|assistant_end|>": 4,
+    "<|python_start|>": 5,
+    "<|python_end|>": 6,
+    "<|output_start|>": 7,
+    "<|output_end|>": 8,
+}
+
+# GPT-4-style split pattern (nanochat variant: digits grouped 1-2 instead of 1-3).
+# Note the leading apostrophe: it anchors the contraction branch ('s, 'll, 've, ...).
+SPLIT_PATTERN = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"
+
+
+class NanochatTokenizer:
+    def __init__(self, model_dir: str):
+        token_bytes_path = os.path.join(model_dir, "token_bytes.pt")
+        tokenizer_pkl_path = os.path.join(model_dir, "tokenizer.pkl")
+
+        if os.path.exists(tokenizer_pkl_path):
+            with open(tokenizer_pkl_path, "rb") as f:
+                loaded = pickle.load(f)
+            # Handle different pickle formats
+            if isinstance(loaded, dict):
+                mergeable_ranks = loaded
+            elif hasattr(loaded, '_mergeable_ranks'):
+                # It's a tiktoken Encoding object
+                mergeable_ranks = loaded._mergeable_ranks
+            else:
+                # Assume it's a pre-built encoder and use it directly
+                self._enc = loaded
+                return
+        elif os.path.exists(token_bytes_path):
+            # Rebuild the rank table from the raw token byte sequences.
+            import torch
+            token_bytes = torch.load(token_bytes_path, weights_only=True)
+            mergeable_ranks = {bytes(token_bytes[i].tolist()): i for i in range(len(token_bytes))}
+        else:
+            # Last resort: fetch the reference tokenizer from HuggingFace.
+            from huggingface_hub import hf_hub_download
+            path = hf_hub_download("karpathy/nanochat-d32", "tokenizer.pkl")
+            with open(path, "rb") as f:
+                mergeable_ranks = pickle.load(f)
+
+        self._enc = tiktoken.Encoding(
+            name="nanochat",
+            pat_str=SPLIT_PATTERN,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=SPECIAL_TOKENS,
+        )
+
+    def encode(self, text: str) -> list[int]:
+        return self._enc.encode(text, allowed_special=set())
+
+    def decode(self, tokens: list[int]) -> str:
+        return self._enc.decode(tokens)
+
+    def encode_special(self, token_name: str) -> list[int]:
+        return self._enc.encode(token_name, allowed_special="all")
+
+    def get_vocab_size(self) -> int:
+        return self._enc.n_vocab
+
+
+def get_tokenizer(model_dir: str | None = None) -> NanochatTokenizer:
+    if model_dir is None:
+        model_dir = "/weights/d20"
+    return NanochatTokenizer(model_dir)
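
A round-trip sketch for the tokenizer, assuming a directory containing the files `get_tokenizer()` looks for (it defaults to `/weights/d20` and falls back to the HuggingFace download):

```python
from _tokenizer import get_tokenizer

tok = get_tokenizer()  # defaults to /weights/d20
ids = tok.encode("hello world")
assert tok.decode(ids) == "hello world"

# Plain encode() refuses special-token text (tiktoken's disallowed_special
# default raises on it); control tokens must be requested explicitly:
print(tok.encode_special("<|user_start|>"))  # [1], per SPECIAL_TOKENS above
print(tok.get_vocab_size())                  # should line up with GPTConfig.vocab_size
```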
diff --git a/modal/serve.py b/modal/serve.py
new file mode 100644
index 00000000..86ef15c4
--- /dev/null
+++ b/modal/serve.py
@@ -0,0 +1,273 @@
+"""
+samosaChaat — Modal GPU inference endpoint.
+
+Downloads nanochat model weights from HuggingFace into a Modal Volume,
+loads them on a GPU, and exposes an SSE streaming endpoint compatible
+with the samosaChaat chat-api service.
+
+Deploy: modal deploy modal/serve.py
+Dev:    modal serve modal/serve.py
+"""
+from __future__ import annotations
+
+import json
+import os
+import time
+
+import modal
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+MODEL_REPO = "nanochat-students/base-d20"  # 1 GB, native nanochat format
+MODEL_PT = "model_021400.pt"
+META_JSON = "meta_021400.json"
+MODEL_TAG = "d20"
+GPU_TYPE = "T4"  # cheapest, 16 GB VRAM — plenty for a 1 GB model
+VOLUME_NAME = "samosachaat-weights"
+
+# ---------------------------------------------------------------------------
+# Modal app + image
+# ---------------------------------------------------------------------------
+app = modal.App("samosachaat-inference")
+
+# Build the container image with all dependencies
+inference_image = (
+    modal.Image.debian_slim(python_version="3.12")
+    .pip_install(
+        "torch==2.5.1",
+        "tiktoken>=0.11.0",
+        "tokenizers>=0.22.0",
+        "huggingface_hub>=0.25.0",
+        "fastapi>=0.115.0",
+        "uvicorn>=0.30.0",
+        extra_index_url="https://download.pytorch.org/whl/cu124",
+    )
+    .add_local_file("modal/_model.py", "/root/_model.py")
+    .add_local_file("modal/_tokenizer.py", "/root/_tokenizer.py")
+)
+
+# Persistent volume for model weights
+volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
+
+# ---------------------------------------------------------------------------
+# Download weights into the volume (runs once)
+# ---------------------------------------------------------------------------
+@app.function(
+    image=inference_image,
+    volumes={"/weights": volume},
+    timeout=600,
+)
+def download_weights():
+    """Download model weights from HuggingFace into the Modal volume."""
+    import shutil
+
+    from huggingface_hub import hf_hub_download
+
+    model_dir = f"/weights/{MODEL_TAG}"
+    os.makedirs(model_dir, exist_ok=True)
+
+    for filename in [MODEL_PT, META_JSON, "token_bytes.pt", "tokenizer.pkl"]:
+        dest = os.path.join(model_dir, filename)
+        if os.path.exists(dest):
+            print(f"  Already exists: {dest}")
+            continue
+        print(f"  Downloading {filename} from {MODEL_REPO}...")
+        path = hf_hub_download(MODEL_REPO, filename)
+        # Copy from the HF cache into the volume
+        shutil.copy2(path, dest)
+        print(f"  Saved to {dest}")
+
+    volume.commit()
+    print("Weights downloaded and committed to volume.")
+
+
+# ---------------------------------------------------------------------------
+# Inference class — GPU singleton
+# ---------------------------------------------------------------------------
+@app.cls(
+    image=inference_image,
+    volumes={"/weights": volume},
+    gpu=GPU_TYPE,
+    scaledown_window=300,  # keep warm for 5 min after last request
+    timeout=120,
+)
+class Inference:
+    model: object
+    tokenizer: object
+    config: object
+    device: object
+
+    @modal.enter()
+    def load_model(self):
+        """Called once when the container starts — loads model onto GPU."""
+        import torch
+
+        # _model.py and _tokenizer.py were baked into the image at /root (the
+        # working directory), so they import directly. This avoids pulling in
+        # the full nanochat package (which has heavy deps we don't need here).
+        print("Loading model...")
+        t0 = time.time()
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
+
+        model_dir = f"/weights/{MODEL_TAG}"
+        meta_path = os.path.join(model_dir, META_JSON)
+        model_path = os.path.join(model_dir, MODEL_PT)
+
+        # Load meta
+        with open(meta_path) as f:
+            meta = json.load(f)
+        model_config = meta if "model_config" not in meta else meta["model_config"]
+
+        # Patch missing config keys
+        model_config.setdefault("window_pattern", "L")
+
+        print(f"  Config: {model_config}")
+
+        # Build the model from the standalone GPT class baked into the image
+        from _model import GPT, GPTConfig
+
+        config = GPTConfig(**model_config)
+        model_data = torch.load(model_path, map_location=device, weights_only=False)
+
+        # Strip the torch.compile prefix, if present
+        model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
+
+        # Patch missing keys (in case the checkpoint lacks the per-layer lambdas)
+        n_layer = config.n_layer
+        if "resid_lambdas" not in model_data:
+            model_data["resid_lambdas"] = torch.ones(n_layer, device=device)
+        if "x0_lambdas" not in model_data:
+            model_data["x0_lambdas"] = torch.zeros(n_layer, device=device)
+
+        # Convert bfloat16 weights to float32 for compatibility
+        model_data = {
+            k: v.float() if v.dtype == torch.bfloat16 else v
+            for k, v in model_data.items()
+        }
+
+        # Auto-detect architecture from checkpoint keys, then load
+        model = GPT.from_state_dict(config, model_data)
+        model.to(device)
+        model.init_weights()
+        model.load_state_dict(model_data, strict=True, assign=True)
+        model.eval()
+
+        self.model = model
+        self.config = config
+
+        # Load tokenizer
+        from _tokenizer import get_tokenizer
+        self.tokenizer = get_tokenizer(model_dir)
+
+        dt = time.time() - t0
+        print(f"Model loaded in {dt:.1f}s on {device}")
+
+    @modal.fastapi_endpoint(method="POST", docs=True)
+    async def generate(self, request: dict):
+        """
+        Streaming chat endpoint — SSE compatible with samosaChaat format.
+
+        Input:  {"messages": [{"role": "user", "content": "..."}], "temperature": 0.8, "max_tokens": 512, "top_k": 50}
+        Output: SSE stream of data: {"token": "...", "gpu": 0} events, then data: {"done": true}
+        """
+        import torch
+        from fastapi.responses import StreamingResponse
+
+        messages = request.get("messages", [])
+        temperature = min(max(request.get("temperature", 0.8), 0.0), 2.0)
+        max_tokens = min(max(request.get("max_tokens", 512), 1), 2048)
+        top_k = min(max(request.get("top_k", 50), 0), 200)
+
+        # Build the token sequence from the chat messages
+        tokens = []
+        bos = self.tokenizer.encode_special("<|bos|>")
+        user_start = self.tokenizer.encode_special("<|user_start|>")
+        user_end = self.tokenizer.encode_special("<|user_end|>")
+        assistant_start = self.tokenizer.encode_special("<|assistant_start|>")
+        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
+
+        tokens.extend(bos)
+        for msg in messages:
+            if msg["role"] == "user":
+                tokens.extend(user_start)
+                tokens.extend(self.tokenizer.encode(msg["content"]))
+                tokens.extend(user_end)
+            elif msg["role"] == "assistant":
+                tokens.extend(assistant_start)
+                tokens.extend(self.tokenizer.encode(msg["content"]))
+                tokens.extend(assistant_end)
+        # Prompt the model to generate an assistant response
+        tokens.extend(assistant_start)
+
+        # Truncate to fit the context window
+        max_context = self.config.sequence_len - max_tokens
+        if len(tokens) > max_context:
+            tokens = tokens[-max_context:]
+
+        stop_ids = {assistant_end[0], bos[0]}
+
+        async def stream():
+            input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
+
+            for _ in range(max_tokens):
+                # Scope the no-grad context to a single step: this is an async
+                # generator, and holding torch.no_grad() open across the yields
+                # below would leak the disabled-grad state to other tasks running
+                # on this thread.
+                with torch.inference_mode():
+                    # Full re-forward each step (no KV cache in this minimal model)
+                    logits = self.model(input_ids)
+                    next_logits = logits[:, -1, :]
+
+                    if temperature > 0:
+                        next_logits = next_logits / temperature
+
+                        # Top-k filtering
+                        if top_k > 0:
+                            v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
+                            next_logits[next_logits < v[:, [-1]]] = float('-inf')
+
+                        # Sample
+                        probs = torch.softmax(next_logits, dim=-1)
+                        next_token = torch.multinomial(probs, num_samples=1)
+                    else:
+                        # temperature == 0: greedy decoding
+                        next_token = next_logits.argmax(dim=-1, keepdim=True)
+
+                    token_id = next_token.item()
+
+                    # Append for the next iteration, truncating to the context window
+                    input_ids = torch.cat([input_ids, next_token], dim=1)
+                    if input_ids.size(1) > self.config.sequence_len:
+                        input_ids = input_ids[:, -self.config.sequence_len:]
+
+                # Check for stop tokens
+                if token_id in stop_ids:
+                    break
+
+                # Decode and yield
+                token_text = self.tokenizer.decode([token_id])
+                yield f"data: {json.dumps({'token': token_text, 'gpu': 0})}\n\n"
+
+            yield f"data: {json.dumps({'done': True})}\n\n"
+
+        return StreamingResponse(
+            stream(),
+            media_type="text/event-stream",
+            headers={
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+                "X-Accel-Buffering": "no",
+            },
+        )
+
+    @modal.fastapi_endpoint(method="GET", docs=True)
+    def health(self):
+        return {
+            "status": "ok",
+            "model": MODEL_TAG,
+            "gpu": GPU_TYPE,
+            "ready": getattr(self, "model", None) is not None,
+        }
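
Finally, a client-side sketch for consuming the SSE stream. The URL is a placeholder, since Modal prints the real endpoint URL on `modal deploy`:

```python
import json

import requests

# Placeholder: substitute the URL printed by `modal deploy modal/serve.py`.
URL = "https://<workspace>--samosachaat-inference-inference-generate.modal.run"

resp = requests.post(
    URL,
    json={"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 64},
    stream=True,
)
for line in resp.iter_lines():
    if not line.startswith(b"data: "):
        continue  # skip blank keep-alive lines
    event = json.loads(line[len(b"data: "):])
    if event.get("done"):
        break
    print(event["token"], end="", flush=True)
print()
```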