From 1d2a76eec49cb952422c01e1bbfe08a3181e42ec Mon Sep 17 00:00:00 2001
From: Manmohan <66306483+manmohan659@users.noreply.github.com>
Date: Thu, 16 Apr 2026 19:55:16 -0400
Subject: [PATCH] feat: deploy d24 SFT + polished UI redesign with dark mode
 (#39)

* feat(inference): deploy d24 SFT weights to Modal

Repoints the Modal inference app from the broken d20 checkpoint to our own
ManmohanSharma/nanochat-d24 SFT step 484. Rewrites the standalone model as an
inference-only port of nanochat/gpt.py so the modern architecture (smear gate,
per-layer value embeddings, ve_gate, backout, sliding-window attention via
SDPA, rotary base 100000, padded vocab, logit softcap) loads cleanly from the
checkpoint.

The tokenizer loads the pickled tiktoken encoding directly so special tokens
end up at their true IDs (32759-32767), and the stop check uses that set
instead of hardcoded IDs 0-8.

GPU bumped to L4 for headroom. HF token sourced from the 'huggingface' Modal
secret.

Co-Authored-By: Claude Opus 4.7 (1M context)

* feat(frontend): polished redesign with serif display + dark mode

Lifts the craft level of the landing and chat UI without changing the desi
identity. Adds Fraunces for display headlines, a floating pill LandingNav, a
saffron-glow hero with a large serif headline and black pill CTAs, and three
gradient-tiled feature cards with inline SVG glyphs replacing the emoji cards.

The chat empty state is now a serif greeting with pill-chip prompt starters,
and ChatInput is a single rounded pod so the send button sits inside the
input (fixes the misaligned floating button).

Adds a class-based dark mode across the chat surfaces with a sun/moon toggle
in the sidebar footer, powered by a small useTheme hook and a no-flash init
script in the root layout.

Co-Authored-By: Claude Opus 4.7 (1M context)

* chore(frontend): add ESLint config so the CI lint step passes

next lint was failing with an interactive prompt because the repo had no
ESLint config. Adds a minimal config extending next/core-web-vitals and drops
the now-unloadable @typescript-eslint/no-explicit-any disable directive in the
stream proxy by narrowing the body type to unknown.

Co-Authored-By: Claude Opus 4.7 (1M context)

---------

Co-authored-by: Claude Opus 4.7 (1M context)
---
 modal/_model.py                               | 306 ++++++++++--------
 modal/_tokenizer.py                           |  42 +--
 modal/serve.py                                |  91 ++++--
 services/frontend/.eslintrc.json              |   3 +
 .../frontend/app/api/chat/stream/route.ts     |   3 +-
 services/frontend/app/chat/page.tsx           |   6 +-
 services/frontend/app/globals.css             |  21 ++
 services/frontend/app/layout.tsx              |  31 +-
 services/frontend/app/page.tsx                |   6 +-
 services/frontend/components/LandingNav.tsx   |  77 +++--
 .../frontend/components/chat/ChatInput.tsx    |  82 +++--
 .../frontend/components/chat/ChatWindow.tsx   |  11 +-
 .../frontend/components/chat/EmptyState.tsx   |  70 ++--
 .../components/chat/MessageBubble.tsx         |  10 +-
 services/frontend/components/chat/Sidebar.tsx |  65 ++--
 .../frontend/components/landing/Features.tsx  | 174 ++++++++--
 services/frontend/components/landing/Hero.tsx | 149 +++++----
 services/frontend/hooks/useTheme.ts           |  38 +++
 services/frontend/tailwind.config.ts          |  20 ++
 19 files changed, 766 insertions(+), 439 deletions(-)
 create mode 100644 services/frontend/.eslintrc.json
 create mode 100644 services/frontend/hooks/useTheme.ts

diff --git a/modal/_model.py b/modal/_model.py
index a6a1da47..f576b005 100644
--- a/modal/_model.py
+++ b/modal/_model.py
@@ -1,202 +1,244 @@
 """
-Minimal standalone GPT model for Modal inference.
-Extracted from nanochat/gpt.py — only the forward-pass code needed for inference.
-No training, no DDP, no flash_attention dependency. +Inference-only port of nanochat/gpt.py. + +Matches the actual nanochat GPT architecture used by d24 SFT checkpoints: +- Smear gate (cheap bigram mixing) +- Backout (mid-layer residual subtraction) +- Per-layer value embeddings (alternating layers, last layer always) +- ve_gate per layer with value embedding +- Sliding-window attention (window_pattern, e.g. "SSSL"), via SDPA +- Rotary embeddings with base=100000, split-halves layout +- Padded vocab (multiple of 64) +- Softcap on logits +- No KV cache (naive autoregressive generate is fine for short responses) """ +from __future__ import annotations from dataclasses import dataclass + import torch import torch.nn as nn import torch.nn.functional as F -import math @dataclass class GPTConfig: sequence_len: int = 2048 - vocab_size: int = 65536 - n_layer: int = 20 - n_head: int = 10 - n_kv_head: int = 10 - n_embd: int = 1280 - window_pattern: str = "L" + vocab_size: int = 32768 + n_layer: int = 24 + n_head: int = 12 + n_kv_head: int = 12 + n_embd: int = 1536 + window_pattern: str = "SSSL" -class RMSNorm(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim +def _norm(x): + return F.rms_norm(x, (x.size(-1),)) + + +class Linear(nn.Linear): + """nn.Linear that casts weights to match input dtype in forward.""" def forward(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) + return F.linear(x, self.weight.to(dtype=x.dtype)) -class RotaryEmbedding(nn.Module): - def __init__(self, dim, max_seq_len=2048): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.max_seq_len = max_seq_len - - def forward(self, x, offset=0): - seq_len = x.shape[-2] - t = torch.arange(offset, offset + seq_len, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - return emb.cos(), emb.sin() +def has_ve(layer_idx: int, n_layer: int) -> bool: + """Layers with a value embedding (alternating, last layer always included).""" + return layer_idx % 2 == (n_layer - 1) % 2 -def apply_rotary_pos_emb(q, k, cos, sin): - def rotate_half(x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed +def apply_rotary_emb(x, cos, sin): + assert x.ndim == 4 + d = x.shape[3] // 2 + x1, x2 = x[..., :d], x[..., d:] + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat([y1, y2], 3) class CausalSelfAttention(nn.Module): - def __init__(self, config: GPTConfig, use_v_emb: bool = False): + def __init__(self, config: GPTConfig, layer_idx: int): super().__init__() + self.layer_idx = layer_idx self.n_head = config.n_head self.n_kv_head = config.n_kv_head - self.head_dim = config.n_embd // config.n_head self.n_embd = config.n_embd - self.use_v_emb = use_v_emb + self.head_dim = self.n_embd // self.n_head + self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False) + self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) + self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) + self.c_proj = Linear(self.n_embd, self.n_embd, bias=False) + self.ve_gate_channels = 12 + if has_ve(layer_idx, config.n_layer): + self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) + else: + self.ve_gate = None - self.c_q = 
nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False) - self.c_k = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False) - self.c_v = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False) - self.c_proj = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False) - - self.q_norm = RMSNorm(self.head_dim) - self.k_norm = RMSNorm(self.head_dim) - - if use_v_emb: - self.v_emb = nn.Parameter(torch.zeros(1, config.n_kv_head, config.sequence_len, self.head_dim)) - - self.rotary = RotaryEmbedding(self.head_dim, config.sequence_len) - - def forward(self, x): + def forward(self, x, ve, cos_sin, window_size): B, T, C = x.size() + # (B, T, H, D) layout + q = self.c_q(x).view(B, T, self.n_head, self.head_dim) + k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) + v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) - q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) - k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) - v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) + if ve is not None: + ve = ve.view(B, T, self.n_kv_head, self.head_dim) + gate = 3.0 * torch.sigmoid(self.ve_gate(x[..., : self.ve_gate_channels])) # (B, T, n_kv_head) + v = v + gate.unsqueeze(-1) * ve - # QK norm - q = self.q_norm(q) - k = self.k_norm(k) + cos, sin = cos_sin + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q, k = _norm(q), _norm(k) + q = q * 1.2 + k = k * 1.2 - # Rotary embeddings - cos, sin = self.rotary(q) - q, k = apply_rotary_pos_emb(q, k, cos, sin) + # SDPA wants (B, H, T, D) + q_sdpa = q.transpose(1, 2) + k_sdpa = k.transpose(1, 2) + v_sdpa = v.transpose(1, 2) + enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) - # GQA: repeat k,v if n_kv_head < n_head - if self.n_kv_head < self.n_head: - rep = self.n_head // self.n_kv_head - k = k.repeat_interleave(rep, dim=1) - v = v.repeat_interleave(rep, dim=1) + window = window_size[0] + if window < 0 or window >= T: + y = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, is_causal=True, enable_gqa=enable_gqa) + else: + # Sliding window mask (left=window) + device = q_sdpa.device + row_idx = torch.arange(T, device=device).unsqueeze(1) + col_idx = torch.arange(T, device=device).unsqueeze(0) + mask = (col_idx <= row_idx) & ((row_idx - col_idx) <= window) + y = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, enable_gqa=enable_gqa) - # Value embeddings (if enabled) - if self.use_v_emb: - v = v + self.v_emb[:, :, :T, :] - - # Scaled dot-product attention (PyTorch native, causal) - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) - - y = y.transpose(1, 2).contiguous().view(B, T, C) + y = y.transpose(1, 2).contiguous().view(B, T, -1) return self.c_proj(y) class MLP(nn.Module): - def __init__(self, config: GPTConfig, gated: bool = False): + def __init__(self, config: GPTConfig): super().__init__() - self.gated = gated - if gated: - hidden = int(config.n_embd * 8 / 3) - hidden = ((hidden + 63) // 64) * 64 - self.c_fc = nn.Linear(config.n_embd, hidden, bias=False) - self.c_fc2 = nn.Linear(config.n_embd, hidden, bias=False) - self.c_proj = nn.Linear(hidden, config.n_embd, bias=False) - else: - hidden = 4 * config.n_embd - self.c_fc = nn.Linear(config.n_embd, hidden, bias=False) - self.c_proj = nn.Linear(hidden, config.n_embd, bias=False) + self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False) + self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False) def forward(self, x): 
- if self.gated: - a = self.c_fc(x) - b = self.c_fc2(x) - return self.c_proj(F.relu(a).pow(2) * b) - else: - return self.c_proj(F.relu(self.c_fc(x)).pow(2)) + x = self.c_fc(x) + x = F.relu(x).square() + x = self.c_proj(x) + return x class Block(nn.Module): - def __init__(self, config: GPTConfig, layer_idx: int, gated_mlp: bool = False, use_v_emb: bool = False): + def __init__(self, config: GPTConfig, layer_idx: int): super().__init__() - self.ln_1 = RMSNorm(config.n_embd) - self.attn = CausalSelfAttention(config, use_v_emb=use_v_emb) - self.ln_2 = RMSNorm(config.n_embd) - self.mlp = MLP(config, gated=gated_mlp) - self.layer_idx = layer_idx + self.attn = CausalSelfAttention(config, layer_idx) + self.mlp = MLP(config) - def forward(self, x, resid_lambda=1.0, x0_lambda=0.0, x0=None): - h = x * resid_lambda + self.attn(self.ln_1(x)) - if x0 is not None and x0_lambda != 0.0: - h = h + x0_lambda * x0 - h2 = h * resid_lambda + self.mlp(self.ln_2(h)) - if x0 is not None and x0_lambda != 0.0: - h2 = h2 + x0_lambda * x0 - return h2 + def forward(self, x, ve, cos_sin, window_size): + x = x + self.attn(_norm(x), ve, cos_sin, window_size) + x = x + self.mlp(_norm(x)) + return x + + +def _compute_window_sizes(config: GPTConfig): + pattern = config.window_pattern.upper() + long_window = config.sequence_len + short_window = -(-long_window // 4 // 128) * 128 + char_to_window = {"L": (long_window, 0), "S": (short_window, 0)} + sizes = [char_to_window[pattern[i % len(pattern)]] for i in range(config.n_layer)] + sizes[-1] = (long_window, 0) + return sizes + + +def _precompute_rotary(seq_len, head_dim, base=100000, device="cpu", dtype=torch.float32): + channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) + inv_freq = 1.0 / (base ** (channel_range / head_dim)) + t = torch.arange(seq_len, dtype=torch.float32, device=device) + freqs = torch.outer(t, inv_freq) + cos = freqs.cos().to(dtype)[None, :, None, :] + sin = freqs.sin().to(dtype)[None, :, None, :] + return cos, sin class GPT(nn.Module): - def __init__(self, config: GPTConfig, gated_mlp: bool = False, use_v_emb: bool = False): + def __init__(self, config: GPTConfig, pad_vocab_size_to: int = 64): super().__init__() self.config = config + self.window_sizes = _compute_window_sizes(config) - self.transformer = nn.ModuleDict(dict( - wte=nn.Embedding(config.vocab_size, config.n_embd), - norm_emb=RMSNorm(config.n_embd), - h=nn.ModuleList([Block(config, i, gated_mlp=gated_mlp, use_v_emb=use_v_emb) for i in range(config.n_layer)]), - ln_f=RMSNorm(config.n_embd), - )) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + padded = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to + self.padded_vocab_size = padded + + self.transformer = nn.ModuleDict({ + "wte": nn.Embedding(padded, config.n_embd), + "h": nn.ModuleList([Block(config, i) for i in range(config.n_layer)]), + }) + self.lm_head = Linear(config.n_embd, padded, bias=False) - # Residual lambdas (per-layer scaling) self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) + self.smear_gate = Linear(24, 1, bias=False) + self.smear_lambda = nn.Parameter(torch.zeros(1)) + self.backout_lambda = nn.Parameter(0.2 * torch.ones(1)) + + head_dim = config.n_embd // config.n_head + kv_dim = config.n_kv_head * head_dim + self.value_embeds = nn.ModuleDict( + {str(i): nn.Embedding(padded, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)} + ) + + # Rotary 
buffers (registered non-persistent — recomputed in init_rotary) + self.rotary_seq_len = config.sequence_len * 10 + self.register_buffer("cos", torch.zeros(1), persistent=False) + self.register_buffer("sin", torch.zeros(1), persistent=False) + @classmethod def from_state_dict(cls, config: GPTConfig, state_dict: dict): - """Auto-detect architecture features from checkpoint keys.""" - gated = any("c_fc2" in k for k in state_dict) - v_emb = any("v_emb" in k for k in state_dict) - model = cls(config, gated_mlp=gated, use_v_emb=v_emb) - return model + # Architecture is fixed for this checkpoint family; kept for API compat. + return cls(config) + def init_rotary(self, device, dtype): + head_dim = self.config.n_embd // self.config.n_head + cos, sin = _precompute_rotary(self.rotary_seq_len, head_dim, base=100000, device=device, dtype=dtype) + self.cos = cos + self.sin = sin + + # Kept for compatibility with serve.py's existing init_weights() call. def init_weights(self): - """Initialize rotary embeddings and value embeddings.""" - for module in self.modules(): - if isinstance(module, RotaryEmbedding): - inv_freq = 1.0 / (10000 ** (torch.arange(0, module.inv_freq.shape[0] * 2, 2).float() / (module.inv_freq.shape[0] * 2))) - module.inv_freq.copy_(inv_freq) + pass def forward(self, idx): B, T = idx.size() - assert T <= self.config.sequence_len, f"Input length {T} exceeds max {self.config.sequence_len}" + assert T <= self.cos.size(1), f"Sequence length {T} exceeds rotary cache {self.cos.size(1)}" + cos_sin = self.cos[:, :T], self.sin[:, :T] x = self.transformer.wte(idx) - x = self.transformer.norm_emb(x) - x0 = x # save for residual connections + x = _norm(x) + # Smear: bigram mixing (training/prefill path; T >= 1 — guarded for T==1) + if T > 1: + gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24])) + x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1) + + x0 = x + n_layer = self.config.n_layer + backout_layer = n_layer // 2 + x_backout = None for i, block in enumerate(self.transformer.h): - rl = self.resid_lambdas[i].item() - xl = self.x0_lambdas[i].item() - x = block(x, resid_lambda=rl, x0_lambda=xl, x0=x0) + x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None + x = block(x, ve, cos_sin, self.window_sizes[i]) + if i == backout_layer: + x_backout = x - x = self.transformer.ln_f(x) + if x_backout is not None: + x = x - self.backout_lambda.to(x.dtype) * x_backout + x = _norm(x) + + softcap = 15.0 logits = self.lm_head(x) + logits = logits[..., : self.config.vocab_size] + logits = logits.float() + logits = softcap * torch.tanh(logits / softcap) return logits diff --git a/modal/_tokenizer.py b/modal/_tokenizer.py index 59fff2ba..f630afaf 100644 --- a/modal/_tokenizer.py +++ b/modal/_tokenizer.py @@ -1,14 +1,16 @@ """ Minimal standalone tokenizer for Modal inference. -Uses tiktoken for fast encoding/decoding with nanochat's special tokens. + +Loads the pickled tiktoken Encoding from a nanochat tokenizer/ directory and +exposes encode / decode / encode_special methods used by serve.py. 
""" import os import pickle + import tiktoken -# nanochat special tokens SPECIAL_TOKENS = { "<|bos|>": 0, "<|user_start|>": 1, @@ -21,26 +23,26 @@ SPECIAL_TOKENS = { "<|output_end|>": 8, } -# GPT-4 split pattern -SPLIT_PATTERN = r"(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+" +# nanochat split pattern (matches nanochat/tokenizer.py) +SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" class NanochatTokenizer: def __init__(self, model_dir: str): + pkl_path = os.path.join(model_dir, "tokenizer.pkl") token_bytes_path = os.path.join(model_dir, "token_bytes.pt") - tokenizer_pkl_path = os.path.join(model_dir, "tokenizer.pkl") - if os.path.exists(tokenizer_pkl_path): - with open(tokenizer_pkl_path, "rb") as f: + if os.path.exists(pkl_path): + with open(pkl_path, "rb") as f: loaded = pickle.load(f) - # Handle different pickle formats + if isinstance(loaded, tiktoken.Encoding): + self._enc = loaded + return if isinstance(loaded, dict): mergeable_ranks = loaded - elif hasattr(loaded, '_mergeable_ranks'): - # It's a tiktoken Encoding object + elif hasattr(loaded, "_mergeable_ranks"): mergeable_ranks = loaded._mergeable_ranks else: - # Try to use it as a pre-built encoder self._enc = loaded return elif os.path.exists(token_bytes_path): @@ -48,32 +50,30 @@ class NanochatTokenizer: token_bytes = torch.load(token_bytes_path, weights_only=True) mergeable_ranks = {bytes(token_bytes[i].tolist()): i for i in range(len(token_bytes))} else: - from huggingface_hub import hf_hub_download - path = hf_hub_download("karpathy/nanochat-d32", "tokenizer.pkl") - with open(path, "rb") as f: - mergeable_ranks = pickle.load(f) + raise FileNotFoundError(f"No tokenizer found in {model_dir}") + # nanochat appends specials at the end of the merge table + offset = len(mergeable_ranks) + special_tokens = {name: offset + i for i, name in enumerate(SPECIAL_TOKENS)} self._enc = tiktoken.Encoding( name="nanochat", pat_str=SPLIT_PATTERN, mergeable_ranks=mergeable_ranks, - special_tokens=SPECIAL_TOKENS, + special_tokens=special_tokens, ) def encode(self, text: str) -> list[int]: - return self._enc.encode(text, allowed_special=set()) + return self._enc.encode_ordinary(text) def decode(self, tokens: list[int]) -> str: return self._enc.decode(tokens) def encode_special(self, token_name: str) -> list[int]: - return self._enc.encode(token_name, allowed_special="all") + return [self._enc.encode_single_token(token_name)] def get_vocab_size(self) -> int: return self._enc.n_vocab -def get_tokenizer(model_dir: str | None = None) -> NanochatTokenizer: - if model_dir is None: - model_dir = "/weights/d20" +def get_tokenizer(model_dir: str) -> NanochatTokenizer: return NanochatTokenizer(model_dir) diff --git a/modal/serve.py b/modal/serve.py index 86ef15c4..15a15651 100644 --- a/modal/serve.py +++ b/modal/serve.py @@ -19,12 +19,15 @@ import modal # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- -MODEL_REPO = "nanochat-students/base-d20" # 1 GB, native nanochat format -MODEL_PT = "model_021400.pt" -META_JSON = "meta_021400.json" -MODEL_TAG = "d20" -GPU_TYPE = "T4" # cheapest, 16 GB VRAM — plenty for 1 GB model +MODEL_REPO = "ManmohanSharma/nanochat-d24" +MODEL_PT = "chatsft_checkpoints/d24/model_000484.pt" +META_JSON = "chatsft_checkpoints/d24/meta_000484.json" +TOKENIZER_PKL = 
"tokenizer/tokenizer.pkl" +TOKEN_BYTES = "tokenizer/token_bytes.pt" +MODEL_TAG = "d24-sft" +GPU_TYPE = "L4" # 24 GB VRAM — fits 4 GB bf16 model loaded as fp32 VOLUME_NAME = "samosachaat-weights" +HF_SECRET_NAME = "huggingface" # Modal secret containing HF_TOKEN # --------------------------------------------------------------------------- # Modal app + image @@ -56,24 +59,34 @@ volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True) @app.function( image=inference_image, volumes={"/weights": volume}, - timeout=600, + secrets=[modal.Secret.from_name(HF_SECRET_NAME)], + timeout=1800, ) def download_weights(): """Download model weights from HuggingFace into the Modal volume.""" + import shutil from huggingface_hub import hf_hub_download model_dir = f"/weights/{MODEL_TAG}" os.makedirs(model_dir, exist_ok=True) - for filename in [MODEL_PT, META_JSON, "token_bytes.pt", "tokenizer.pkl"]: - dest = os.path.join(model_dir, filename) + token = os.environ.get("HF_TOKEN") + + # (HF source path, local filename in volume) + files = [ + (MODEL_PT, "model.pt"), + (META_JSON, "meta.json"), + (TOKENIZER_PKL, "tokenizer.pkl"), + (TOKEN_BYTES, "token_bytes.pt"), + ] + + for src, local_name in files: + dest = os.path.join(model_dir, local_name) if os.path.exists(dest): print(f" Already exists: {dest}") continue - print(f" Downloading {filename} from {MODEL_REPO}...") - path = hf_hub_download(MODEL_REPO, filename) - # Copy to volume - import shutil + print(f" Downloading {src} from {MODEL_REPO}...") + path = hf_hub_download(MODEL_REPO, src, token=token) shutil.copy2(path, dest) print(f" Saved to {dest}") @@ -114,14 +127,26 @@ class Inference: self.device = device model_dir = f"/weights/{MODEL_TAG}" - meta_path = os.path.join(model_dir, META_JSON) - model_path = os.path.join(model_dir, MODEL_PT) + meta_path = os.path.join(model_dir, "meta.json") + model_path = os.path.join(model_dir, "model.pt") # Load meta with open(meta_path) as f: meta = json.load(f) model_config = meta if "model_config" not in meta else meta["model_config"] + # Normalize config key names (HF format → nanochat format) + # Map HF config keys → nanochat GPTConfig keys + seq_len = model_config.pop("n_positions", None) or model_config.pop("n_ctx", None) + if seq_len and "sequence_len" not in model_config: + model_config["sequence_len"] = seq_len + # Also remove n_ctx if sequence_len was already set + model_config.pop("n_ctx", None) + model_config.pop("n_positions", None) + # Remove HF-specific keys that GPTConfig doesn't accept + for k in ["architectures", "model_type", "rotary", "rotary_base", "tie_word_embeddings"]: + model_config.pop(k, None) + # Patch missing config keys model_config.setdefault("window_pattern", "L") @@ -135,37 +160,36 @@ class Inference: config = GPTConfig(**model_config) model_data = torch.load(model_path, map_location=device, weights_only=False) - # Fix torch compile prefix + # Strip torch.compile prefix model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} - # Patch missing keys - n_layer = config.n_layer - if "resid_lambdas" not in model_data: - model_data["resid_lambdas"] = torch.ones(n_layer) - if "x0_lambdas" not in model_data: - model_data["x0_lambdas"] = torch.zeros(n_layer) - - # Auto-detect architecture from checkpoint - # Convert bfloat16 weights to float32 for compatibility + # Convert bfloat16 weights to float32 for compatibility on non-Hopper GPUs model_data = { k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items() } - # Auto-detect architecture from 
checkpoint model = GPT.from_state_dict(config, model_data) - model.to(device) - model.init_weights() model.load_state_dict(model_data, strict=True, assign=True) + model.to(device) + model.init_rotary(device=device, dtype=torch.float32) model.eval() self.model = model self.config = config # Load tokenizer - from _tokenizer import get_tokenizer + from _tokenizer import get_tokenizer, SPECIAL_TOKENS self.tokenizer = get_tokenizer(model_dir) + # Resolve actual special-token IDs (nanochat appends specials at end of vocab) + self.special_token_ids = set() + for name in SPECIAL_TOKENS: + ids = self.tokenizer.encode_special(name) + self.special_token_ids.update(ids) + self.assistant_end_id = self.tokenizer.encode_special("<|assistant_end|>")[0] + print(f" Special token IDs: {sorted(self.special_token_ids)}") + dt = time.time() - t0 print(f"Model loaded in {dt:.1f}s on {device}") @@ -236,12 +260,15 @@ class Inference: token_id = next_token.item() - # Check for stop tokens - if token_id in [t[0] for t in [assistant_end, bos]]: + # Stop on any special token (assistant_end, bos, etc.) + if token_id in self.special_token_ids: break - # Decode and yield - token_text = self.tokenizer.decode([token_id]) + # Decode and yield (skip tokens that can't be decoded) + try: + token_text = self.tokenizer.decode([token_id]) + except (KeyError, Exception): + continue yield f"data: {json.dumps({'token': token_text, 'gpu': 0})}\n\n" # Append for next iteration diff --git a/services/frontend/.eslintrc.json b/services/frontend/.eslintrc.json new file mode 100644 index 00000000..bffb357a --- /dev/null +++ b/services/frontend/.eslintrc.json @@ -0,0 +1,3 @@ +{ + "extends": "next/core-web-vitals" +} diff --git a/services/frontend/app/api/chat/stream/route.ts b/services/frontend/app/api/chat/stream/route.ts index f2140e4f..c5af8a2a 100644 --- a/services/frontend/app/api/chat/stream/route.ts +++ b/services/frontend/app/api/chat/stream/route.ts @@ -18,8 +18,7 @@ function sseEvent(data: Record) { return encoder.encode(`data: ${JSON.stringify(data)}\n\n`); } -// eslint-disable-next-line @typescript-eslint/no-explicit-any -async function proxyUpstream(body: any, upstreamUrl: string, authHeader: string | null) { +async function proxyUpstream(body: unknown, upstreamUrl: string, authHeader: string | null) { const headers: Record = { 'Content-Type': 'application/json' }; if (authHeader) headers['Authorization'] = authHeader; diff --git a/services/frontend/app/chat/page.tsx b/services/frontend/app/chat/page.tsx index bea612b0..6a76df24 100644 --- a/services/frontend/app/chat/page.tsx +++ b/services/frontend/app/chat/page.tsx @@ -11,7 +11,7 @@ function ChatContent() { const { authenticated, loading } = useAuth(); if (loading) { - return
Loading...
; + return
Loading…
; } if (!authenticated) { redirect('/login'); @@ -19,7 +19,7 @@ function ChatContent() { } return ( -
+
@@ -28,7 +28,7 @@ function ChatContent() { export default function ChatPage() { return ( - Loading...}> + Loading…}> ); diff --git a/services/frontend/app/globals.css b/services/frontend/app/globals.css index 23df16dc..56f059ba 100644 --- a/services/frontend/app/globals.css +++ b/services/frontend/app/globals.css @@ -13,6 +13,10 @@ --light-cream: #fffdf5; } +.dark { + color-scheme: dark; +} + html, body { height: 100%; } body { @@ -21,6 +25,10 @@ body { overflow-x: hidden; } +.dark body { + @apply bg-ink text-ink-text; +} + /* Markdown prose tweaks inside message bubbles */ .markdown-body > *:first-child { margin-top: 0 !important; } .markdown-body > *:last-child { margin-bottom: 0 !important; } @@ -30,6 +38,9 @@ body { .markdown-body code:not(pre code) { @apply px-1 py-0.5 rounded bg-cream-light border border-cream-border text-brown text-[0.9em]; } +.dark .markdown-body code:not(pre code) { + @apply bg-ink-elev border-ink-border text-saffron-soft; +} .markdown-body p { @apply my-2 leading-relaxed; } .markdown-body ul { @apply list-disc pl-6 my-2; } .markdown-body ol { @apply list-decimal pl-6 my-2; } @@ -39,12 +50,22 @@ body { .markdown-body blockquote { @apply border-l-4 border-cream-border pl-4 italic text-brown-light my-2; } +.dark .markdown-body blockquote { + @apply border-ink-border text-ink-text-soft; +} .markdown-body a { @apply text-chutney-green underline hover:text-gold; } /* Scrollbar tuning */ .nice-scrollbar::-webkit-scrollbar { width: 6px; height: 6px; } .nice-scrollbar::-webkit-scrollbar-thumb { background: #e0d5c0; border-radius: 3px; } .nice-scrollbar::-webkit-scrollbar-thumb:hover { background: var(--warm-grey); } +.dark .nice-scrollbar::-webkit-scrollbar-thumb { background: #2a2a2e; } +.dark .nice-scrollbar::-webkit-scrollbar-thumb:hover { background: #3a3a40; } /* Highlight.js minimal tweaks */ .hljs { background: transparent; } + +/* Soft serif optical sizing for display headlines */ +.font-display { + font-optical-sizing: auto; +} diff --git a/services/frontend/app/layout.tsx b/services/frontend/app/layout.tsx index 2f86f67e..cb0ea9e7 100644 --- a/services/frontend/app/layout.tsx +++ b/services/frontend/app/layout.tsx @@ -1,5 +1,5 @@ import type { Metadata, Viewport } from 'next'; -import { Baloo_2, Great_Vibes, Caveat, Inter } from 'next/font/google'; +import { Baloo_2, Great_Vibes, Caveat, Inter, Fraunces } from 'next/font/google'; import './globals.css'; const baloo = Baloo_2({ @@ -29,9 +29,17 @@ const inter = Inter({ display: 'swap', }); +const fraunces = Fraunces({ + subsets: ['latin'], + weight: ['400', '500', '600', '700'], + variable: '--font-fraunces', + display: 'swap', +}); + export const metadata: Metadata = { title: 'समोसाचाट — samosaChaat', - description: 'Crafted with care. For India, from India. A warm, desi-flavored chat experience powered by nanochat.', + description: + 'Crafted with care. For India, from India. A warm, desi-flavored chat experience powered by nanochat.', icons: { icon: '/logo.svg' }, }; @@ -42,10 +50,25 @@ export const viewport: Viewport = { viewportFit: 'cover', }; +// Set theme class before paint to avoid flash +const themeInitScript = ` +(function(){try{ + var t=localStorage.getItem('theme'); + if(t==='dark'){document.documentElement.classList.add('dark');} + else if(t==='light'){document.documentElement.classList.remove('dark');} +}catch(e){}})(); +`; + export default function RootLayout({ children }: { children: React.ReactNode }) { return ( - - + + +
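For reference, the dark-mode wiring described in the commit message reduces to the no-flash init script shown in layout.tsx plus a small hook that flips the `dark` class and persists the choice. The sketch below is only an illustration of that pattern, reusing the 'theme' localStorage key and the `dark` class visible in the init script; the actual hooks/useTheme.ts added by this PR (and how the Sidebar toggle consumes it) may differ in detail.

'use client';

import { useCallback, useEffect, useState } from 'react';

// Illustrative sketch only; the real hooks/useTheme.ts in this PR may differ.
// It follows the conventions of the layout init script: a 'theme' key in
// localStorage holding 'dark' or 'light', and a `dark` class on <html>.
export function useTheme() {
  const [theme, setTheme] = useState<'light' | 'dark'>('light');

  // Read the persisted choice once on mount (after the no-flash script has run).
  useEffect(() => {
    const stored = window.localStorage.getItem('theme');
    if (stored === 'dark' || stored === 'light') setTheme(stored);
  }, []);

  const toggle = useCallback(() => {
    setTheme((prev) => {
      const next = prev === 'dark' ? 'light' : 'dark';
      document.documentElement.classList.toggle('dark', next === 'dark');
      window.localStorage.setItem('theme', next);
      return next;
    });
  }, []);

  return { theme, toggle };
}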