diff --git a/modal/_model.py b/modal/_model.py
index a6a1da47..f576b005 100644
--- a/modal/_model.py
+++ b/modal/_model.py
@@ -1,202 +1,244 @@
 """
-Minimal standalone GPT model for Modal inference.
-Extracted from nanochat/gpt.py — only the forward-pass code needed for inference.
-No training, no DDP, no flash_attention dependency.
+Inference-only port of nanochat/gpt.py.
+
+Matches the actual nanochat GPT architecture used by d24 SFT checkpoints:
+- Smear gate (cheap bigram mixing)
+- Backout (mid-layer residual subtraction)
+- Per-layer value embeddings (alternating layers, last layer always)
+- ve_gate per layer with value embedding
+- Sliding-window attention (window_pattern, e.g. "SSSL"), via SDPA
+- Rotary embeddings with base=100000, split-halves layout
+- Padded vocab (multiple of 64)
+- Softcap on logits
+- No KV cache (naive autoregressive generate is fine for short responses)
 """
+from __future__ import annotations
 
 from dataclasses import dataclass
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import math
 
 
 @dataclass
 class GPTConfig:
     sequence_len: int = 2048
-    vocab_size: int = 65536
-    n_layer: int = 20
-    n_head: int = 10
-    n_kv_head: int = 10
-    n_embd: int = 1280
-    window_pattern: str = "L"
+    vocab_size: int = 32768
+    n_layer: int = 24
+    n_head: int = 12
+    n_kv_head: int = 12
+    n_embd: int = 1536
+    window_pattern: str = "SSSL"
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
+def _norm(x):
+    return F.rms_norm(x, (x.size(-1),))
+
+
+class Linear(nn.Linear):
+    """nn.Linear that casts weights to match input dtype in forward."""
 
     def forward(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
+        return F.linear(x, self.weight.to(dtype=x.dtype))
 
 
-class RotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_seq_len=2048):
-        super().__init__()
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.max_seq_len = max_seq_len
-
-    def forward(self, x, offset=0):
-        seq_len = x.shape[-2]
-        t = torch.arange(offset, offset + seq_len, device=x.device, dtype=self.inv_freq.dtype)
-        freqs = torch.outer(t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        return emb.cos(), emb.sin()
+def has_ve(layer_idx: int, n_layer: int) -> bool:
+    """Layers with a value embedding (alternating, last layer always included)."""
+    return layer_idx % 2 == (n_layer - 1) % 2
 
 
-def apply_rotary_pos_emb(q, k, cos, sin):
-    def rotate_half(x):
-        x1, x2 = x.chunk(2, dim=-1)
-        return torch.cat((-x2, x1), dim=-1)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+def apply_rotary_emb(x, cos, sin):
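+    # Split-halves RoPE layout: channel i pairs with channel i + d rather than
+    # the interleaved even/odd layout; cos/sin broadcast over batch and heads.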
+    assert x.ndim == 4
+    d = x.shape[3] // 2
+    x1, x2 = x[..., :d], x[..., d:]
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    return torch.cat([y1, y2], 3)
 
 
 class CausalSelfAttention(nn.Module):
-    def __init__(self, config: GPTConfig, use_v_emb: bool = False):
+    def __init__(self, config: GPTConfig, layer_idx: int):
         super().__init__()
+        self.layer_idx = layer_idx
         self.n_head = config.n_head
         self.n_kv_head = config.n_kv_head
-        self.head_dim = config.n_embd // config.n_head
         self.n_embd = config.n_embd
-        self.use_v_emb = use_v_emb
+        self.head_dim = self.n_embd // self.n_head
+        self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
+        self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
+        self.ve_gate_channels = 12
+        if has_ve(layer_idx, config.n_layer):
+            self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False)
+        else:
+            self.ve_gate = None
 
-        self.c_q = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
-        self.c_k = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False)
-        self.c_v = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False)
-        self.c_proj = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False)
-
-        self.q_norm = RMSNorm(self.head_dim)
-        self.k_norm = RMSNorm(self.head_dim)
-
-        if use_v_emb:
-            self.v_emb = nn.Parameter(torch.zeros(1, config.n_kv_head, config.sequence_len, self.head_dim))
-
-        self.rotary = RotaryEmbedding(self.head_dim, config.sequence_len)
-
-    def forward(self, x):
+    def forward(self, x, ve, cos_sin, window_size):
         B, T, C = x.size()
+        # (B, T, H, D) layout
+        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
 
-        q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
-        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
-        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
+        if ve is not None:
+            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
+            gate = 3.0 * torch.sigmoid(self.ve_gate(x[..., : self.ve_gate_channels]))  # (B, T, n_kv_head)
+            v = v + gate.unsqueeze(-1) * ve
 
-        # QK norm
-        q = self.q_norm(q)
-        k = self.k_norm(k)
+        cos, sin = cos_sin
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        q, k = _norm(q), _norm(k)
+        q = q * 1.2
+        k = k * 1.2
 
-        # Rotary embeddings
-        cos, sin = self.rotary(q)
-        q, k = apply_rotary_pos_emb(q, k, cos, sin)
+        # SDPA wants (B, H, T, D)
+        q_sdpa = q.transpose(1, 2)
+        k_sdpa = k.transpose(1, 2)
+        v_sdpa = v.transpose(1, 2)
+        enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
 
-        # GQA: repeat k,v if n_kv_head < n_head
-        if self.n_kv_head < self.n_head:
-            rep = self.n_head // self.n_kv_head
-            k = k.repeat_interleave(rep, dim=1)
-            v = v.repeat_interleave(rep, dim=1)
+        window = window_size[0]
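+        # window_size is a (left, right) tuple; a left window of w lets each
+        # query attend to itself and at most w earlier tokens.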
+        if window < 0 or window >= T:
+            y = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, is_causal=True, enable_gqa=enable_gqa)
+        else:
+            # Sliding window mask (left=window)
+            device = q_sdpa.device
+            row_idx = torch.arange(T, device=device).unsqueeze(1)
+            col_idx = torch.arange(T, device=device).unsqueeze(0)
+            mask = (col_idx <= row_idx) & ((row_idx - col_idx) <= window)
+            y = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, enable_gqa=enable_gqa)
 
-        # Value embeddings (if enabled)
-        if self.use_v_emb:
-            v = v + self.v_emb[:, :, :T, :]
-
-        # Scaled dot-product attention (PyTorch native, causal)
-        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
-
-        y = y.transpose(1, 2).contiguous().view(B, T, C)
+        y = y.transpose(1, 2).contiguous().view(B, T, -1)
         return self.c_proj(y)
 
 
 class MLP(nn.Module):
-    def __init__(self, config: GPTConfig, gated: bool = False):
+    def __init__(self, config: GPTConfig):
         super().__init__()
-        self.gated = gated
-        if gated:
-            hidden = int(config.n_embd * 8 / 3)
-            hidden = ((hidden + 63) // 64) * 64
-            self.c_fc = nn.Linear(config.n_embd, hidden, bias=False)
-            self.c_fc2 = nn.Linear(config.n_embd, hidden, bias=False)
-            self.c_proj = nn.Linear(hidden, config.n_embd, bias=False)
-        else:
-            hidden = 4 * config.n_embd
-            self.c_fc = nn.Linear(config.n_embd, hidden, bias=False)
-            self.c_proj = nn.Linear(hidden, config.n_embd, bias=False)
+        self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False)
+        self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False)
 
     def forward(self, x):
-        if self.gated:
-            a = self.c_fc(x)
-            b = self.c_fc2(x)
-            return self.c_proj(F.relu(a).pow(2) * b)
-        else:
-            return self.c_proj(F.relu(self.c_fc(x)).pow(2))
+        x = self.c_fc(x)
+        x = F.relu(x).square()
+        x = self.c_proj(x)
+        return x
 
 
 class Block(nn.Module):
-    def __init__(self, config: GPTConfig, layer_idx: int, gated_mlp: bool = False, use_v_emb: bool = False):
+    def __init__(self, config: GPTConfig, layer_idx: int):
         super().__init__()
-        self.ln_1 = RMSNorm(config.n_embd)
-        self.attn = CausalSelfAttention(config, use_v_emb=use_v_emb)
-        self.ln_2 = RMSNorm(config.n_embd)
-        self.mlp = MLP(config, gated=gated_mlp)
-        self.layer_idx = layer_idx
+        self.attn = CausalSelfAttention(config, layer_idx)
+        self.mlp = MLP(config)
 
-    def forward(self, x, resid_lambda=1.0, x0_lambda=0.0, x0=None):
-        h = x * resid_lambda + self.attn(self.ln_1(x))
-        if x0 is not None and x0_lambda != 0.0:
-            h = h + x0_lambda * x0
-        h2 = h * resid_lambda + self.mlp(self.ln_2(h))
-        if x0 is not None and x0_lambda != 0.0:
-            h2 = h2 + x0_lambda * x0
-        return h2
+    def forward(self, x, ve, cos_sin, window_size):
+        x = x + self.attn(_norm(x), ve, cos_sin, window_size)
+        x = x + self.mlp(_norm(x))
+        return x
+
+
+def _compute_window_sizes(config: GPTConfig):
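+    # Tile window_pattern (e.g. "SSSL") across the layers: "S" layers get a
+    # short window (seq_len / 4, rounded up to a multiple of 128), "L" layers
+    # the full context; the last layer is always forced to a long window.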
+    pattern = config.window_pattern.upper()
+    long_window = config.sequence_len
+    short_window = -(-long_window // 4 // 128) * 128
+    char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
+    sizes = [char_to_window[pattern[i % len(pattern)]] for i in range(config.n_layer)]
+    sizes[-1] = (long_window, 0)
+    return sizes
+
+
+def _precompute_rotary(seq_len, head_dim, base=100000, device="cpu", dtype=torch.float32):
+    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+    inv_freq = 1.0 / (base ** (channel_range / head_dim))
+    t = torch.arange(seq_len, dtype=torch.float32, device=device)
+    freqs = torch.outer(t, inv_freq)
+    cos = freqs.cos().to(dtype)[None, :, None, :]
+    sin = freqs.sin().to(dtype)[None, :, None, :]
+    return cos, sin
 
 
 class GPT(nn.Module):
-    def __init__(self, config: GPTConfig, gated_mlp: bool = False, use_v_emb: bool = False):
+    def __init__(self, config: GPTConfig, pad_vocab_size_to: int = 64):
         super().__init__()
         self.config = config
+        self.window_sizes = _compute_window_sizes(config)
 
-        self.transformer = nn.ModuleDict(dict(
-            wte=nn.Embedding(config.vocab_size, config.n_embd),
-            norm_emb=RMSNorm(config.n_embd),
-            h=nn.ModuleList([Block(config, i, gated_mlp=gated_mlp, use_v_emb=use_v_emb) for i in range(config.n_layer)]),
-            ln_f=RMSNorm(config.n_embd),
-        ))
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        padded = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
+        self.padded_vocab_size = padded
+
+        self.transformer = nn.ModuleDict({
+            "wte": nn.Embedding(padded, config.n_embd),
+            "h": nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
+        })
+        self.lm_head = Linear(config.n_embd, padded, bias=False)
 
-        # Residual lambdas (per-layer scaling)
         self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
         self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
+        self.smear_gate = Linear(24, 1, bias=False)
+        self.smear_lambda = nn.Parameter(torch.zeros(1))
+        self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
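+        # smear_lambda scales the bigram smear applied right after the token
+        # embedding; backout_lambda scales the mid-stack residual that
+        # forward() subtracts again before the final norm.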
""" import os import pickle + import tiktoken -# nanochat special tokens SPECIAL_TOKENS = { "<|bos|>": 0, "<|user_start|>": 1, @@ -21,26 +23,26 @@ SPECIAL_TOKENS = { "<|output_end|>": 8, } -# GPT-4 split pattern -SPLIT_PATTERN = r"(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+" +# nanochat split pattern (matches nanochat/tokenizer.py) +SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" class NanochatTokenizer: def __init__(self, model_dir: str): + pkl_path = os.path.join(model_dir, "tokenizer.pkl") token_bytes_path = os.path.join(model_dir, "token_bytes.pt") - tokenizer_pkl_path = os.path.join(model_dir, "tokenizer.pkl") - if os.path.exists(tokenizer_pkl_path): - with open(tokenizer_pkl_path, "rb") as f: + if os.path.exists(pkl_path): + with open(pkl_path, "rb") as f: loaded = pickle.load(f) - # Handle different pickle formats + if isinstance(loaded, tiktoken.Encoding): + self._enc = loaded + return if isinstance(loaded, dict): mergeable_ranks = loaded - elif hasattr(loaded, '_mergeable_ranks'): - # It's a tiktoken Encoding object + elif hasattr(loaded, "_mergeable_ranks"): mergeable_ranks = loaded._mergeable_ranks else: - # Try to use it as a pre-built encoder self._enc = loaded return elif os.path.exists(token_bytes_path): @@ -48,32 +50,30 @@ class NanochatTokenizer: token_bytes = torch.load(token_bytes_path, weights_only=True) mergeable_ranks = {bytes(token_bytes[i].tolist()): i for i in range(len(token_bytes))} else: - from huggingface_hub import hf_hub_download - path = hf_hub_download("karpathy/nanochat-d32", "tokenizer.pkl") - with open(path, "rb") as f: - mergeable_ranks = pickle.load(f) + raise FileNotFoundError(f"No tokenizer found in {model_dir}") + # nanochat appends specials at the end of the merge table + offset = len(mergeable_ranks) + special_tokens = {name: offset + i for i, name in enumerate(SPECIAL_TOKENS)} self._enc = tiktoken.Encoding( name="nanochat", pat_str=SPLIT_PATTERN, mergeable_ranks=mergeable_ranks, - special_tokens=SPECIAL_TOKENS, + special_tokens=special_tokens, ) def encode(self, text: str) -> list[int]: - return self._enc.encode(text, allowed_special=set()) + return self._enc.encode_ordinary(text) def decode(self, tokens: list[int]) -> str: return self._enc.decode(tokens) def encode_special(self, token_name: str) -> list[int]: - return self._enc.encode(token_name, allowed_special="all") + return [self._enc.encode_single_token(token_name)] def get_vocab_size(self) -> int: return self._enc.n_vocab -def get_tokenizer(model_dir: str | None = None) -> NanochatTokenizer: - if model_dir is None: - model_dir = "/weights/d20" +def get_tokenizer(model_dir: str) -> NanochatTokenizer: return NanochatTokenizer(model_dir) diff --git a/modal/serve.py b/modal/serve.py index 86ef15c4..15a15651 100644 --- a/modal/serve.py +++ b/modal/serve.py @@ -19,12 +19,15 @@ import modal # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- -MODEL_REPO = "nanochat-students/base-d20" # 1 GB, native nanochat format -MODEL_PT = "model_021400.pt" -META_JSON = "meta_021400.json" -MODEL_TAG = "d20" -GPU_TYPE = "T4" # cheapest, 16 GB VRAM — plenty for 1 GB model +MODEL_REPO = "ManmohanSharma/nanochat-d24" +MODEL_PT = "chatsft_checkpoints/d24/model_000484.pt" +META_JSON = "chatsft_checkpoints/d24/meta_000484.json" +TOKENIZER_PKL = 
"tokenizer/tokenizer.pkl" +TOKEN_BYTES = "tokenizer/token_bytes.pt" +MODEL_TAG = "d24-sft" +GPU_TYPE = "L4" # 24 GB VRAM — fits 4 GB bf16 model loaded as fp32 VOLUME_NAME = "samosachaat-weights" +HF_SECRET_NAME = "huggingface" # Modal secret containing HF_TOKEN # --------------------------------------------------------------------------- # Modal app + image @@ -56,24 +59,34 @@ volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True) @app.function( image=inference_image, volumes={"/weights": volume}, - timeout=600, + secrets=[modal.Secret.from_name(HF_SECRET_NAME)], + timeout=1800, ) def download_weights(): """Download model weights from HuggingFace into the Modal volume.""" + import shutil from huggingface_hub import hf_hub_download model_dir = f"/weights/{MODEL_TAG}" os.makedirs(model_dir, exist_ok=True) - for filename in [MODEL_PT, META_JSON, "token_bytes.pt", "tokenizer.pkl"]: - dest = os.path.join(model_dir, filename) + token = os.environ.get("HF_TOKEN") + + # (HF source path, local filename in volume) + files = [ + (MODEL_PT, "model.pt"), + (META_JSON, "meta.json"), + (TOKENIZER_PKL, "tokenizer.pkl"), + (TOKEN_BYTES, "token_bytes.pt"), + ] + + for src, local_name in files: + dest = os.path.join(model_dir, local_name) if os.path.exists(dest): print(f" Already exists: {dest}") continue - print(f" Downloading {filename} from {MODEL_REPO}...") - path = hf_hub_download(MODEL_REPO, filename) - # Copy to volume - import shutil + print(f" Downloading {src} from {MODEL_REPO}...") + path = hf_hub_download(MODEL_REPO, src, token=token) shutil.copy2(path, dest) print(f" Saved to {dest}") @@ -114,14 +127,26 @@ class Inference: self.device = device model_dir = f"/weights/{MODEL_TAG}" - meta_path = os.path.join(model_dir, META_JSON) - model_path = os.path.join(model_dir, MODEL_PT) + meta_path = os.path.join(model_dir, "meta.json") + model_path = os.path.join(model_dir, "model.pt") # Load meta with open(meta_path) as f: meta = json.load(f) model_config = meta if "model_config" not in meta else meta["model_config"] + # Normalize config key names (HF format → nanochat format) + # Map HF config keys → nanochat GPTConfig keys + seq_len = model_config.pop("n_positions", None) or model_config.pop("n_ctx", None) + if seq_len and "sequence_len" not in model_config: + model_config["sequence_len"] = seq_len + # Also remove n_ctx if sequence_len was already set + model_config.pop("n_ctx", None) + model_config.pop("n_positions", None) + # Remove HF-specific keys that GPTConfig doesn't accept + for k in ["architectures", "model_type", "rotary", "rotary_base", "tie_word_embeddings"]: + model_config.pop(k, None) + # Patch missing config keys model_config.setdefault("window_pattern", "L") @@ -135,37 +160,36 @@ class Inference: config = GPTConfig(**model_config) model_data = torch.load(model_path, map_location=device, weights_only=False) - # Fix torch compile prefix + # Strip torch.compile prefix model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} - # Patch missing keys - n_layer = config.n_layer - if "resid_lambdas" not in model_data: - model_data["resid_lambdas"] = torch.ones(n_layer) - if "x0_lambdas" not in model_data: - model_data["x0_lambdas"] = torch.zeros(n_layer) - - # Auto-detect architecture from checkpoint - # Convert bfloat16 weights to float32 for compatibility + # Convert bfloat16 weights to float32 for compatibility on non-Hopper GPUs model_data = { k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items() } - # Auto-detect architecture from 
+        offset = len(mergeable_ranks)
+        special_tokens = {name: offset + i for i, name in enumerate(SPECIAL_TOKENS)}
         self._enc = tiktoken.Encoding(
             name="nanochat",
             pat_str=SPLIT_PATTERN,
             mergeable_ranks=mergeable_ranks,
-            special_tokens=SPECIAL_TOKENS,
+            special_tokens=special_tokens,
         )
 
     def encode(self, text: str) -> list[int]:
-        return self._enc.encode(text, allowed_special=set())
+        return self._enc.encode_ordinary(text)
 
     def decode(self, tokens: list[int]) -> str:
         return self._enc.decode(tokens)
 
     def encode_special(self, token_name: str) -> list[int]:
-        return self._enc.encode(token_name, allowed_special="all")
+        return [self._enc.encode_single_token(token_name)]
 
     def get_vocab_size(self) -> int:
         return self._enc.n_vocab
 
 
-def get_tokenizer(model_dir: str | None = None) -> NanochatTokenizer:
-    if model_dir is None:
-        model_dir = "/weights/d20"
+def get_tokenizer(model_dir: str) -> NanochatTokenizer:
     return NanochatTokenizer(model_dir)
diff --git a/modal/serve.py b/modal/serve.py
index 86ef15c4..15a15651 100644
--- a/modal/serve.py
+++ b/modal/serve.py
@@ -19,12 +19,15 @@ import modal
 # ---------------------------------------------------------------------------
 # Configuration
 # ---------------------------------------------------------------------------
-MODEL_REPO = "nanochat-students/base-d20"  # 1 GB, native nanochat format
-MODEL_PT = "model_021400.pt"
-META_JSON = "meta_021400.json"
-MODEL_TAG = "d20"
-GPU_TYPE = "T4"  # cheapest, 16 GB VRAM — plenty for 1 GB model
+MODEL_REPO = "ManmohanSharma/nanochat-d24"
+MODEL_PT = "chatsft_checkpoints/d24/model_000484.pt"
+META_JSON = "chatsft_checkpoints/d24/meta_000484.json"
+TOKENIZER_PKL = "tokenizer/tokenizer.pkl"
+TOKEN_BYTES = "tokenizer/token_bytes.pt"
+MODEL_TAG = "d24-sft"
+GPU_TYPE = "L4"  # 24 GB VRAM — fits the ~4 GB bf16 checkpoint even after fp32 conversion (~8 GB)
 VOLUME_NAME = "samosachaat-weights"
+HF_SECRET_NAME = "huggingface"  # Modal secret containing HF_TOKEN
 
 # ---------------------------------------------------------------------------
 # Modal app + image
@@ -56,24 +59,34 @@ volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
 @app.function(
     image=inference_image,
     volumes={"/weights": volume},
-    timeout=600,
+    secrets=[modal.Secret.from_name(HF_SECRET_NAME)],
+    timeout=1800,
 )
 def download_weights():
     """Download model weights from HuggingFace into the Modal volume."""
+    import shutil
     from huggingface_hub import hf_hub_download
 
     model_dir = f"/weights/{MODEL_TAG}"
    os.makedirs(model_dir, exist_ok=True)
 
-    for filename in [MODEL_PT, META_JSON, "token_bytes.pt", "tokenizer.pkl"]:
-        dest = os.path.join(model_dir, filename)
+    token = os.environ.get("HF_TOKEN")
+
+    # (HF source path, local filename in volume)
+    files = [
+        (MODEL_PT, "model.pt"),
+        (META_JSON, "meta.json"),
+        (TOKENIZER_PKL, "tokenizer.pkl"),
+        (TOKEN_BYTES, "token_bytes.pt"),
+    ]
+
+    for src, local_name in files:
+        dest = os.path.join(model_dir, local_name)
         if os.path.exists(dest):
             print(f"  Already exists: {dest}")
             continue
-        print(f"  Downloading {filename} from {MODEL_REPO}...")
-        path = hf_hub_download(MODEL_REPO, filename)
-        # Copy to volume
-        import shutil
+        print(f"  Downloading {src} from {MODEL_REPO}...")
+        path = hf_hub_download(MODEL_REPO, src, token=token)
         shutil.copy2(path, dest)
         print(f"  Saved to {dest}")
@@ -114,14 +127,26 @@ class Inference:
         self.device = device
 
         model_dir = f"/weights/{MODEL_TAG}"
-        meta_path = os.path.join(model_dir, META_JSON)
-        model_path = os.path.join(model_dir, MODEL_PT)
+        meta_path = os.path.join(model_dir, "meta.json")
+        model_path = os.path.join(model_dir, "model.pt")
 
         # Load meta
         with open(meta_path) as f:
             meta = json.load(f)
         model_config = meta if "model_config" not in meta else meta["model_config"]
 
+        # Map HF-style config keys onto nanochat GPTConfig keys
+        seq_len = model_config.pop("n_positions", None) or model_config.pop("n_ctx", None)
+        if seq_len and "sequence_len" not in model_config:
+            model_config["sequence_len"] = seq_len
+        # Drop any leftover aliases if sequence_len was already set
+        model_config.pop("n_ctx", None)
+        model_config.pop("n_positions", None)
+        # Remove HF-specific keys that GPTConfig doesn't accept
+        for k in ["architectures", "model_type", "rotary", "rotary_base", "tie_word_embeddings"]:
+            model_config.pop(k, None)
+
+        # Patch missing config keys
         model_config.setdefault("window_pattern", "L")
@@ -135,37 +160,36 @@ class Inference:
         config = GPTConfig(**model_config)
 
         model_data = torch.load(model_path, map_location=device, weights_only=False)
-        # Fix torch compile prefix
+        # Strip torch.compile prefix
         model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
 
-        # Patch missing keys
-        n_layer = config.n_layer
-        if "resid_lambdas" not in model_data:
-            model_data["resid_lambdas"] = torch.ones(n_layer)
-        if "x0_lambdas" not in model_data:
-            model_data["x0_lambdas"] = torch.zeros(n_layer)
-
-        # Auto-detect architecture from checkpoint
-        # Convert bfloat16 weights to float32 for compatibility
+        # Convert bfloat16 weights to float32 for compatibility on non-Hopper GPUs
         model_data = {
             k: v.float() if v.dtype == torch.bfloat16 else v
             for k, v in model_data.items()
         }
 
-        # Auto-detect architecture from checkpoint
         model = GPT.from_state_dict(config, model_data)
-        model.to(device)
-        model.init_weights()
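+        # assign=True swaps the checkpoint tensors in directly instead of
+        # copying, so load first, then move to the device and build rotary.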
         model.load_state_dict(model_data, strict=True, assign=True)
+        model.to(device)
+        model.init_rotary(device=device, dtype=torch.float32)
         model.eval()
 
         self.model = model
         self.config = config
 
         # Load tokenizer
-        from _tokenizer import get_tokenizer
+        from _tokenizer import get_tokenizer, SPECIAL_TOKENS
         self.tokenizer = get_tokenizer(model_dir)
 
+        # Resolve actual special-token IDs (nanochat appends specials at end of vocab)
+        self.special_token_ids = set()
+        for name in SPECIAL_TOKENS:
+            ids = self.tokenizer.encode_special(name)
+            self.special_token_ids.update(ids)
+        self.assistant_end_id = self.tokenizer.encode_special("<|assistant_end|>")[0]
+        print(f"  Special token IDs: {sorted(self.special_token_ids)}")
+
         dt = time.time() - t0
         print(f"Model loaded in {dt:.1f}s on {device}")
@@ -236,12 +260,15 @@ class Inference:
 
             token_id = next_token.item()
 
-            # Check for stop tokens
-            if token_id in [t[0] for t in [assistant_end, bos]]:
+            # Stop on any special token (assistant_end, bos, etc.)
+            if token_id in self.special_token_ids:
                 break
 
-            # Decode and yield
-            token_text = self.tokenizer.decode([token_id])
+            # Decode and yield (skip tokens that can't be decoded)
+            try:
+                token_text = self.tokenizer.decode([token_id])
+            except Exception:
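+                # defensive: decode can raise on an id outside the merge table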
+                continue
             yield f"data: {json.dumps({'token': token_text, 'gpu': 0})}\n\n"
 
             # Append for next iteration
diff --git a/services/frontend/.eslintrc.json b/services/frontend/.eslintrc.json
new file mode 100644
index 00000000..bffb357a
--- /dev/null
+++ b/services/frontend/.eslintrc.json
@@ -0,0 +1,3 @@
+{
+  "extends": "next/core-web-vitals"
+}
diff --git a/services/frontend/app/api/chat/stream/route.ts b/services/frontend/app/api/chat/stream/route.ts
index f2140e4f..c5af8a2a 100644
--- a/services/frontend/app/api/chat/stream/route.ts
+++ b/services/frontend/app/api/chat/stream/route.ts
@@ -18,8 +18,7 @@ function sseEvent(data: Record<string, unknown>) {
   return encoder.encode(`data: ${JSON.stringify(data)}\n\n`);
 }
 
-// eslint-disable-next-line @typescript-eslint/no-explicit-any
-async function proxyUpstream(body: any, upstreamUrl: string, authHeader: string | null) {
+async function proxyUpstream(body: unknown, upstreamUrl: string, authHeader: string | null) {
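+  // `unknown` (unlike `any`) forces explicit narrowing before property access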
   const headers: Record<string, string> = { 'Content-Type': 'application/json' };
   if (authHeader) headers['Authorization'] = authHeader;
diff --git a/services/frontend/app/chat/page.tsx b/services/frontend/app/chat/page.tsx
index bea612b0..6a76df24 100644
--- a/services/frontend/app/chat/page.tsx
+++ b/services/frontend/app/chat/page.tsx
@@ -11,7 +11,7 @@ function ChatContent() {
   const { authenticated, loading } = useAuth();
 
   if (loading) {
-    return <div …>Loading...</div>;
+    return <div …>Loading…</div>;
   }
   if (!authenticated) {
     redirect('/login');
@@ -19,7 +19,7 @@
   }
 
   return (
-    <div …>
+    <div …>
       …
@@ -28,7 +28,7 @@ export default function ChatPage() {
   return (
-    <Suspense fallback={<div …>Loading...</div>}>
+    <Suspense fallback={<div …>Loading…</div>}>
       …
   );
diff --git a/services/frontend/app/globals.css b/services/frontend/app/globals.css
index 23df16dc..56f059ba 100644
--- a/services/frontend/app/globals.css
+++ b/services/frontend/app/globals.css
@@ -13,6 +13,10 @@
   --light-cream: #fffdf5;
 }
 
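+/* Dark theme: an inline script in layout.tsx toggles .dark on <html>;
+   color-scheme switches native form controls and scrollbars to dark. */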
+.dark {
+  color-scheme: dark;
+}
+
 html, body { height: 100%; }
 body {
@@ -21,6 +25,10 @@ body {
   overflow-x: hidden;
 }
 
+.dark body {
+  @apply bg-ink text-ink-text;
+}
+
 /* Markdown prose tweaks inside message bubbles */
 .markdown-body > *:first-child { margin-top: 0 !important; }
 .markdown-body > *:last-child { margin-bottom: 0 !important; }
@@ -30,6 +38,9 @@ body {
 .markdown-body code:not(pre code) {
   @apply px-1 py-0.5 rounded bg-cream-light border border-cream-border text-brown text-[0.9em];
 }
+.dark .markdown-body code:not(pre code) {
+  @apply bg-ink-elev border-ink-border text-saffron-soft;
+}
 .markdown-body p { @apply my-2 leading-relaxed; }
 .markdown-body ul { @apply list-disc pl-6 my-2; }
 .markdown-body ol { @apply list-decimal pl-6 my-2; }
@@ -39,12 +50,22 @@ body {
 .markdown-body blockquote {
   @apply border-l-4 border-cream-border pl-4 italic text-brown-light my-2;
 }
+.dark .markdown-body blockquote {
+  @apply border-ink-border text-ink-text-soft;
+}
 .markdown-body a { @apply text-chutney-green underline hover:text-gold; }
 
 /* Scrollbar tuning */
 .nice-scrollbar::-webkit-scrollbar { width: 6px; height: 6px; }
 .nice-scrollbar::-webkit-scrollbar-thumb { background: #e0d5c0; border-radius: 3px; }
 .nice-scrollbar::-webkit-scrollbar-thumb:hover { background: var(--warm-grey); }
+.dark .nice-scrollbar::-webkit-scrollbar-thumb { background: #2a2a2e; }
+.dark .nice-scrollbar::-webkit-scrollbar-thumb:hover { background: #3a3a40; }
 
 /* Highlight.js minimal tweaks */
 .hljs { background: transparent; }
+
+/* Soft serif optical sizing for display headlines */
+.font-display {
+  font-optical-sizing: auto;
+}
diff --git a/services/frontend/app/layout.tsx b/services/frontend/app/layout.tsx
index 2f86f67e..cb0ea9e7 100644
--- a/services/frontend/app/layout.tsx
+++ b/services/frontend/app/layout.tsx
@@ -1,5 +1,5 @@
 import type { Metadata, Viewport } from 'next';
-import { Baloo_2, Great_Vibes, Caveat, Inter } from 'next/font/google';
+import { Baloo_2, Great_Vibes, Caveat, Inter, Fraunces } from 'next/font/google';
 import './globals.css';
 
 const baloo = Baloo_2({
@@ -29,9 +29,17 @@ const inter = Inter({
   display: 'swap',
 });
 
+const fraunces = Fraunces({
+  subsets: ['latin'],
+  weight: ['400', '500', '600', '700'],
+  variable: '--font-fraunces',
+  display: 'swap',
+});
+
 export const metadata: Metadata = {
   title: 'समोसाचाट — samosaChaat',
-  description: 'Crafted with care. For India, from India. A warm, desi-flavored chat experience powered by nanochat.',
+  description:
+    'Crafted with care. For India, from India. A warm, desi-flavored chat experience powered by nanochat.',
   icons: { icon: '/logo.svg' },
 };
 
@@ -42,10 +50,25 @@ export const viewport: Viewport = {
   viewportFit: 'cover',
 };
 
+// Set theme class before paint to avoid flash
+const themeInitScript = `
+(function(){try{
+  var t=localStorage.getItem('theme');
+  if(t==='dark'){document.documentElement.classList.add('dark');}
+  else if(t==='light'){document.documentElement.classList.remove('dark');}
+}catch(e){}})();
+`;
+
 export default function RootLayout({ children }: { children: React.ReactNode }) {
   return (
-    <html …>
-      <body …>
+    <html …>
+      <body …>
+        <script dangerouslySetInnerHTML={{ __html: themeInitScript }} />