mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-12 19:00:14 +00:00
Merge pull request #29 from manmohan659/feat/modal-inference
feat(modal): Modal GPU inference endpoint
This commit is contained in:
commit
95b1ffc0fd
202
modal/_model.py
Normal file
202
modal/_model.py
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
"""
|
||||
Minimal standalone GPT model for Modal inference.
|
||||
Extracted from nanochat/gpt.py — only the forward-pass code needed for inference.
|
||||
No training, no DDP, no flash_attention dependency.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
@dataclass
class GPTConfig:
    """Hyper-parameters describing the shape of the inference-only GPT."""

    # Maximum number of tokens accepted by GPT.forward.
    sequence_len: int = 2048
    # Number of rows of the token embedding / columns of the LM head.
    vocab_size: int = 65536
    # Number of transformer blocks.
    n_layer: int = 20
    # Number of query heads; head_dim is n_embd // n_head.
    n_head: int = 10
    # Number of key/value heads (keys/values are repeated when < n_head).
    n_kv_head: int = 10
    # Width of the residual stream.
    n_embd: int = 1280
    # Not read by any module in this file; accepted so checkpoint metadata
    # that carries this key can be passed straight through as keyword args.
    window_pattern: str = "L"
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension, with no learnable weights."""

    def __init__(self, dim):
        super().__init__()
        # Kept for reference only; forward() never reads it.
        self.dim = dim

    def forward(self, x):
        """Scale ``x`` by the reciprocal RMS of its last axis (epsilon 1e-6)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-6)
        return x * inv_rms
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
    """Produces the cos/sin tables used for rotary position embeddings.

    ``inv_freq`` is registered as a non-persistent buffer, so it is not part of
    the state dict. ``max_seq_len`` is stored but not consulted by ``forward``.
    """

    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents), persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, x, offset=0):
        """Return ``(cos, sin)`` tables for positions ``offset .. offset + x.shape[-2] - 1``.

        Each table has one row per position and ``2 * len(inv_freq)`` columns
        (the frequency block is duplicated along the last axis).
        """
        n_positions = x.shape[-2]
        positions = torch.arange(
            offset,
            offset + n_positions,
            device=x.device,
            dtype=self.inv_freq.dtype,
        )
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        return table.cos(), table.sin()
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate ``q`` and ``k`` using the given cos/sin tables.

    Each tensor is combined as ``t * cos + rotate_half(t) * sin`` where
    ``rotate_half`` swaps the two halves of the last axis and negates the
    second half. Returns the rotated ``(q, k)`` pair.
    """

    def _half_rotated(t):
        first, second = t.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    rotated_q, rotated_k = _half_rotated(q), _half_rotated(k)
    return (q * cos) + (rotated_q * sin), (k * cos) + (rotated_k * sin)
|
||||
|
||||
|
||||
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with QK-norm, rotary embeddings and optional GQA.

    Args:
        config: Model hyper-parameters (``n_embd``, ``n_head``, ``n_kv_head``,
            ``sequence_len`` are read).
        use_v_emb: When true, a learned ``v_emb`` parameter of shape
            ``(1, n_kv_head, sequence_len, head_dim)`` is added to the values.
    """

    def __init__(self, config: GPTConfig, use_v_emb: bool = False):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.head_dim = config.n_embd // config.n_head
        self.n_embd = config.n_embd
        self.use_v_emb = use_v_emb

        self.c_q = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False)

        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

        if use_v_emb:
            self.v_emb = nn.Parameter(torch.zeros(1, config.n_kv_head, config.sequence_len, self.head_dim))

        self.rotary = RotaryEmbedding(self.head_dim, config.sequence_len)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, T, C)`` and return a tensor of the same shape."""
        B, T, C = x.size()

        # Project and split into heads: (B, heads, T, head_dim).
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        # QK norm
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Rotary embeddings (positions 0..T-1; the tables broadcast over B and heads)
        cos, sin = self.rotary(q)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Value embeddings (if enabled). This must happen BEFORE the GQA
        # expansion below: v_emb carries n_kv_head heads, so adding it to the
        # already-expanded (n_head) values fails to broadcast whenever
        # 1 < n_kv_head < n_head. Adding first and repeating afterwards yields
        # the same result in every configuration that previously worked.
        if self.use_v_emb:
            v = v + self.v_emb[:, :, :T, :]

        # GQA: repeat k,v if n_kv_head < n_head
        if self.n_kv_head < self.n_head:
            rep = self.n_head // self.n_kv_head
            k = k.repeat_interleave(rep, dim=1)
            v = v.repeat_interleave(rep, dim=1)

        # Scaled dot-product attention (PyTorch native, causal)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Merge heads back into the channel dimension and project out.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Feed-forward block with a squared-ReLU activation.

    With ``gated=True`` a second up-projection (``c_fc2``) multiplies the
    activated branch and the hidden width is ``8/3 * n_embd`` rounded up to a
    multiple of 64; otherwise the hidden width is ``4 * n_embd``.
    """

    def __init__(self, config: GPTConfig, gated: bool = False):
        super().__init__()
        self.gated = gated
        width = config.n_embd

        if gated:
            hidden = int(width * 8 / 3)
            hidden += -hidden % 64  # round up to the next multiple of 64
        else:
            hidden = 4 * width

        # Layers are created in the order c_fc, (c_fc2), c_proj.
        self.c_fc = nn.Linear(width, hidden, bias=False)
        if gated:
            self.c_fc2 = nn.Linear(width, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        """Apply the feed-forward transformation to ``x``."""
        if not self.gated:
            return self.c_proj(F.relu(self.c_fc(x)).pow(2))
        up = self.c_fc(x)
        gate = self.c_fc2(x)
        return self.c_proj(F.relu(up).pow(2) * gate)
|
||||
|
||||
|
||||
class Block(nn.Module):
    """One transformer layer: pre-norm attention followed by a pre-norm MLP.

    Both residual additions scale the incoming stream by ``resid_lambda`` and,
    when ``x0`` is given and ``x0_lambda`` is non-zero, add ``x0_lambda * x0``.
    """

    def __init__(self, config: GPTConfig, layer_idx: int, gated_mlp: bool = False, use_v_emb: bool = False):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config, use_v_emb=use_v_emb)
        self.ln_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config, gated=gated_mlp)
        self.layer_idx = layer_idx

    def forward(self, x, resid_lambda=1.0, x0_lambda=0.0, x0=None):
        """Run the layer on ``x`` and return the updated residual stream."""
        mix_in_x0 = x0 is not None and x0_lambda != 0.0

        # Attention sub-layer.
        after_attn = x * resid_lambda + self.attn(self.ln_1(x))
        if mix_in_x0:
            after_attn = after_attn + x0_lambda * x0

        # MLP sub-layer.
        after_mlp = after_attn * resid_lambda + self.mlp(self.ln_2(after_attn))
        if mix_in_x0:
            after_mlp = after_mlp + x0_lambda * x0

        return after_mlp
|
||||
|
||||
|
||||
class GPT(nn.Module):
    """Inference-only GPT: token embedding, ``n_layer`` blocks, final norm and LM head.

    Args:
        config: Model hyper-parameters.
        gated_mlp: Build every block's MLP in its gated form.
        use_v_emb: Give every attention layer a learned ``v_emb`` parameter.
    """

    def __init__(self, config: GPTConfig, gated_mlp: bool = False, use_v_emb: bool = False):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            norm_emb=RMSNorm(config.n_embd),
            h=nn.ModuleList([Block(config, i, gated_mlp=gated_mlp, use_v_emb=use_v_emb) for i in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Residual lambdas (per-layer scaling)
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))

    @classmethod
    def from_state_dict(cls, config: GPTConfig, state_dict: dict):
        """Build a model whose optional features match the checkpoint's keys.

        Only the key names are inspected (``c_fc2`` -> gated MLP, ``v_emb`` ->
        value embeddings). The weights themselves are NOT loaded here; the
        caller is expected to call ``load_state_dict`` afterwards.
        """
        gated = any("c_fc2" in k for k in state_dict)
        v_emb = any("v_emb" in k for k in state_dict)
        return cls(config, gated_mlp=gated, use_v_emb=v_emb)

    def init_weights(self):
        """Recompute the ``inv_freq`` buffer of every rotary embedding in place.

        ``inv_freq`` is a non-persistent buffer, so it is never restored from a
        checkpoint. No other parameter or buffer is touched.
        """
        for module in self.modules():
            if isinstance(module, RotaryEmbedding):
                # inv_freq holds dim/2 entries, so the rotary dim is twice its length.
                dim = module.inv_freq.shape[0] * 2
                inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
                module.inv_freq.copy_(inv_freq)

    def forward(self, idx):
        """Return logits of shape ``(B, T, vocab_size)`` for token ids ``idx`` of shape ``(B, T)``."""
        _, T = idx.size()
        assert T <= self.config.sequence_len, f"Input length {T} exceeds max {self.config.sequence_len}"

        x = self.transformer.wte(idx)
        x = self.transformer.norm_emb(x)
        x0 = x  # save for residual connections

        # Pull the per-layer scalars to the host once per forward pass instead
        # of calling .item() twice per layer (each call is a device sync).
        resid_lambdas = self.resid_lambdas.tolist()
        x0_lambdas = self.x0_lambdas.tolist()

        for i, block in enumerate(self.transformer.h):
            x = block(x, resid_lambda=resid_lambdas[i], x0_lambda=x0_lambdas[i], x0=x0)

        x = self.transformer.ln_f(x)
        return self.lm_head(x)
|
||||
79
modal/_tokenizer.py
Normal file
79
modal/_tokenizer.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""
|
||||
Minimal standalone tokenizer for Modal inference.
|
||||
Uses tiktoken for fast encoding/decoding with nanochat's special tokens.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import tiktoken
|
||||
|
||||
|
||||
# nanochat special tokens
|
||||
SPECIAL_TOKENS = {
|
||||
"<|bos|>": 0,
|
||||
"<|user_start|>": 1,
|
||||
"<|user_end|>": 2,
|
||||
"<|assistant_start|>": 3,
|
||||
"<|assistant_end|>": 4,
|
||||
"<|python_start|>": 5,
|
||||
"<|python_end|>": 6,
|
||||
"<|output_start|>": 7,
|
||||
"<|output_end|>": 8,
|
||||
}
|
||||
|
||||
# GPT-4 split pattern
|
||||
SPLIT_PATTERN = r"(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"
|
||||
|
||||
|
||||
class NanochatTokenizer:
    """Thin wrapper around a tiktoken encoding built from nanochat tokenizer files.

    Lookup order inside ``model_dir``: ``tokenizer.pkl``, then
    ``token_bytes.pt``; if neither exists, ``tokenizer.pkl`` is downloaded from
    the ``karpathy/nanochat-d32`` Hugging Face repo.
    """

    def __init__(self, model_dir: str):
        token_bytes_path = os.path.join(model_dir, "token_bytes.pt")
        tokenizer_pkl_path = os.path.join(model_dir, "tokenizer.pkl")

        if os.path.exists(tokenizer_pkl_path):
            loaded = self._load_pickle(tokenizer_pkl_path)
        elif os.path.exists(token_bytes_path):
            import torch

            token_bytes = torch.load(token_bytes_path, weights_only=True)
            # assumes every entry of token_bytes is a sequence of byte values — TODO confirm
            loaded = {bytes(entry.tolist()): rank for rank, entry in enumerate(token_bytes)}
        else:
            from huggingface_hub import hf_hub_download

            # The downloaded pickle goes through the same format handling as a
            # local one instead of being assumed to be a plain rank dict.
            loaded = self._load_pickle(hf_hub_download("karpathy/nanochat-d32", "tokenizer.pkl"))

        # Handle the different payload formats.
        if isinstance(loaded, dict):
            mergeable_ranks = loaded
        elif hasattr(loaded, "_mergeable_ranks"):
            # It's a tiktoken Encoding object; only its merge table is reused.
            mergeable_ranks = loaded._mergeable_ranks
        else:
            # Anything else is used as a ready-made encoder as-is.
            self._enc = loaded
            return

        self._enc = tiktoken.Encoding(
            name="nanochat",
            pat_str=SPLIT_PATTERN,
            mergeable_ranks=mergeable_ranks,
            special_tokens=SPECIAL_TOKENS,
        )

    @staticmethod
    def _load_pickle(path: str):
        """Unpickle and return the object stored at ``path``."""
        # NOTE(security): pickle executes arbitrary code while loading. Only
        # point this at tokenizer files from a trusted source.
        with open(path, "rb") as f:
            return pickle.load(f)

    def encode(self, text: str) -> list[int]:
        """Encode ``text`` as ordinary tokens; special tokens are never produced.

        ``disallowed_special=()`` makes literal special-token strings inside
        ``text`` (for example a user typing ``<|bos|>``) encode as plain text.
        Without it tiktoken raises ``ValueError`` on such input.
        """
        return self._enc.encode(text, allowed_special=set(), disallowed_special=())

    def decode(self, tokens: list[int]) -> str:
        """Decode a list of token ids back into text."""
        return self._enc.decode(tokens)

    def encode_special(self, token_name: str) -> list[int]:
        """Encode ``token_name`` with every special token allowed; returns a list of ids."""
        return self._enc.encode(token_name, allowed_special="all")

    def get_vocab_size(self) -> int:
        """Return the encoder's ``n_vocab``."""
        return self._enc.n_vocab
|
||||
|
||||
|
||||
def get_tokenizer(model_dir: str | None = None) -> NanochatTokenizer:
    """Build a tokenizer from ``model_dir``, defaulting to ``/weights/d20`` when it is ``None``."""
    return NanochatTokenizer("/weights/d20" if model_dir is None else model_dir)
|
||||
273
modal/serve.py
Normal file
273
modal/serve.py
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
"""
|
||||
samosaChaat — Modal GPU inference endpoint.
|
||||
|
||||
Downloads nanochat model weights from HuggingFace into a Modal Volume,
|
||||
loads them on a GPU, and exposes an SSE streaming endpoint compatible
|
||||
with the samosaChaat chat-api service.
|
||||
|
||||
Deploy: modal deploy modal/serve.py
|
||||
Dev: modal serve modal/serve.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import modal
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hugging Face repo holding the checkpoint (1 GB, native nanochat format).
MODEL_REPO = "nanochat-students/base-d20"
# Checkpoint and metadata files share the same step suffix.
_CHECKPOINT_STEP = "021400"
MODEL_PT = f"model_{_CHECKPOINT_STEP}.pt"
META_JSON = f"meta_{_CHECKPOINT_STEP}.json"
# Sub-directory of the weights volume that holds this model's files.
MODEL_TAG = "d20"
# Cheapest option, 16 GB VRAM — plenty for a 1 GB model.
GPU_TYPE = "T4"
# Name of the persistent Modal volume.
VOLUME_NAME = "samosachaat-weights"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Modal app + image
|
||||
# ---------------------------------------------------------------------------
|
||||
app = modal.App("samosachaat-inference")

# Container image: slim Debian + Python 3.12, the pinned inference
# dependencies, and the two local helper modules copied into /root.
_base_image = modal.Image.debian_slim(python_version="3.12")
_image_with_deps = _base_image.pip_install(
    "torch==2.5.1",
    "tiktoken>=0.11.0",
    "tokenizers>=0.22.0",
    "huggingface_hub>=0.25.0",
    "fastapi>=0.115.0",
    "uvicorn>=0.30.0",
    extra_index_url="https://download.pytorch.org/whl/cu124",
)
inference_image = (
    _image_with_deps
    .add_local_file("modal/_model.py", "/root/_model.py")
    .add_local_file("modal/_tokenizer.py", "/root/_tokenizer.py")
)

# Persistent volume for model weights (created on first use).
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Download weights into the volume (runs once)
|
||||
# ---------------------------------------------------------------------------
|
||||
@app.function(
    image=inference_image,
    volumes={"/weights": volume},
    timeout=600,
)
def download_weights():
    """Download model weights from HuggingFace into the Modal volume.

    Files already present under ``/weights/<MODEL_TAG>`` are skipped. Each new
    file is copied under a temporary ``.partial`` name and renamed into place
    only once the copy has finished, so an interrupted run cannot leave a
    truncated file that later runs would mistake for a finished download.
    """
    import shutil

    from huggingface_hub import hf_hub_download

    model_dir = f"/weights/{MODEL_TAG}"
    os.makedirs(model_dir, exist_ok=True)

    for filename in (MODEL_PT, META_JSON, "token_bytes.pt", "tokenizer.pkl"):
        dest = os.path.join(model_dir, filename)
        if os.path.exists(dest):
            print(f" Already exists: {dest}")
            continue

        print(f" Downloading {filename} from {MODEL_REPO}...")
        path = hf_hub_download(MODEL_REPO, filename)

        # Copy to volume, then publish atomically under the final name.
        partial = dest + ".partial"
        shutil.copy2(path, partial)
        os.replace(partial, dest)
        print(f" Saved to {dest}")

    volume.commit()
    print("Weights downloaded and committed to volume.")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference class — GPU singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
@app.cls(
    image=inference_image,
    volumes={"/weights": volume},
    gpu=GPU_TYPE,
    scaledown_window=300,  # keep warm for 5 min after last request
    # NOTE(review): no @modal.concurrent decorator is applied to this class in
    # this file, so per-container concurrency is whatever Modal defaults to.
    timeout=120,
)
class Inference:
    """Modal class that loads the model once per container and serves chat generation."""

    model: object
    tokenizer: object
    engine: object  # declared but never assigned in this class
    device: object

    @modal.enter()
    def load_model(self):
        """Called once when the container starts — loads model and tokenizer.

        Reads the metadata JSON and checkpoint from ``/weights/<MODEL_TAG>``,
        builds a matching ``GPT`` and stores ``model``, ``config``,
        ``tokenizer`` and ``device`` on ``self``.
        """
        import torch

        # The minimal model/tokenizer modules copied into the image are used
        # instead of the full nanochat package.
        from _model import GPT, GPTConfig
        from _tokenizer import get_tokenizer

        print("Loading model...")
        t0 = time.time()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        model_dir = f"/weights/{MODEL_TAG}"
        meta_path = os.path.join(model_dir, META_JSON)
        model_path = os.path.join(model_dir, MODEL_PT)

        # Load meta; the config is either nested under "model_config" or is the
        # whole document.
        with open(meta_path) as f:
            meta = json.load(f)
        model_config = meta if "model_config" not in meta else meta["model_config"]

        # Patch missing config keys
        model_config.setdefault("window_pattern", "L")
        print(f" Config: {model_config}")

        config = GPTConfig(**model_config)

        # NOTE(security): weights_only=False unpickles arbitrary objects from a
        # file downloaded from the Hub. Every value is treated as a tensor
        # below, so weights_only=True is probably sufficient — confirm against
        # the real checkpoint before switching.
        model_data = torch.load(model_path, map_location=device, weights_only=False)

        # Fix torch compile prefix
        model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}

        # Patch missing keys. They are created on the target device because
        # load_state_dict(assign=True) adopts these tensors as-is.
        n_layer = config.n_layer
        if "resid_lambdas" not in model_data:
            model_data["resid_lambdas"] = torch.ones(n_layer, device=device)
        if "x0_lambdas" not in model_data:
            model_data["x0_lambdas"] = torch.zeros(n_layer, device=device)

        # Convert bfloat16 weights to float32 for compatibility
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }

        # Auto-detect architecture from checkpoint
        model = GPT.from_state_dict(config, model_data)
        model.to(device)
        model.init_weights()
        model.load_state_dict(model_data, strict=True, assign=True)
        model.eval()

        self.model = model
        self.config = config

        # Load tokenizer
        self.tokenizer = get_tokenizer(model_dir)

        dt = time.time() - t0
        print(f"Model loaded in {dt:.1f}s on {device}")

    @modal.fastapi_endpoint(method="POST", docs=True)
    async def generate(self, request: dict):
        """
        Streaming chat endpoint — SSE compatible with samosaChaat format.

        Input: {"messages": [{"role": "user", "content": "..."}], "temperature": 0.8, "max_tokens": 512, "top_k": 50}
        Output: SSE stream of data: {"token": "...", "gpu": 0} then data: {"done": true}

        temperature is clamped to [0, 2] (0 selects greedy decoding),
        max_tokens to [1, 2048] and top_k to [0, 200] (0 disables top-k).
        Messages whose role is neither "user" nor "assistant" are ignored.
        """
        import torch
        from fastapi.responses import StreamingResponse

        def clamp(value, default, low, high):
            # A missing key or an explicit JSON null falls back to the default.
            if value is None:
                value = default
            return min(max(value, low), high)

        messages = request.get("messages") or []
        temperature = clamp(request.get("temperature"), 0.8, 0.0, 2.0)
        max_tokens = int(clamp(request.get("max_tokens"), 512, 1, 2048))
        top_k = int(clamp(request.get("top_k"), 50, 0, 200))

        # Build token sequence from messages
        bos = self.tokenizer.encode_special("<|bos|>")
        user_start = self.tokenizer.encode_special("<|user_start|>")
        user_end = self.tokenizer.encode_special("<|user_end|>")
        assistant_start = self.tokenizer.encode_special("<|assistant_start|>")
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")

        tokens = list(bos)
        for msg in messages:
            role = msg["role"]
            if role == "user":
                opener, closer = user_start, user_end
            elif role == "assistant":
                opener, closer = assistant_start, assistant_end
            else:
                continue
            tokens.extend(opener)
            tokens.extend(self.tokenizer.encode(msg["content"]))
            tokens.extend(closer)
        # Prompt the model to generate an assistant response
        tokens.extend(assistant_start)

        # Truncate to fit context, keeping the most recent tokens. When the
        # requested generation budget leaves no room for a prompt
        # (max_tokens >= sequence_len), fall back to the full context window:
        # a zero or negative bound would otherwise turn tokens[-bound:] into
        # "keep everything" or "drop from the wrong end". The sliding window in
        # the loop below keeps later steps within sequence_len.
        seq_len = self.config.sequence_len
        max_context = seq_len - max_tokens
        if max_context <= 0:
            max_context = seq_len
        if len(tokens) > max_context:
            tokens = tokens[-max_context:]

        # Generation stops on either of these ids.
        stop_ids = {assistant_end[0], bos[0]}

        def sse(payload):
            return f"data: {json.dumps(payload)}\n\n"

        async def stream():
            input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
            generated = []
            emitted = ""  # text already sent to the client

            for _ in range(max_tokens):
                # Gradients are disabled only around the tensor work so the
                # context manager is never held open across a yield.
                with torch.no_grad():
                    next_logits = self.model(input_ids)[:, -1, :]
                    if temperature > 0:
                        next_logits = next_logits / temperature
                        # Top-k filtering
                        if top_k > 0:
                            k = min(top_k, next_logits.size(-1))
                            kth = torch.topk(next_logits, k).values[:, [-1]]
                            next_logits = next_logits.masked_fill(next_logits < kth, float("-inf"))
                        probs = torch.softmax(next_logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        # temperature == 0 means deterministic (greedy) decoding.
                        next_token = torch.argmax(next_logits, dim=-1, keepdim=True)

                token_id = next_token.item()

                # Check for stop tokens
                if token_id in stop_ids:
                    break

                # Decode everything generated so far and send only the new
                # suffix. A character whose bytes span several tokens decodes
                # to U+FFFD until it is complete, so hold output back while
                # the text ends in that replacement character.
                generated.append(token_id)
                text = self.tokenizer.decode(generated)
                if not text.endswith("\ufffd"):
                    delta, emitted = text[len(emitted):], text
                    if delta:
                        yield sse({"token": delta, "gpu": 0})

                # Append for next iteration
                input_ids = torch.cat([input_ids, next_token], dim=1)

                # Truncate if exceeding sequence length
                if input_ids.size(1) > seq_len:
                    input_ids = input_ids[:, -seq_len:]

            # Flush anything still held back when generation ended.
            if generated:
                tail = self.tokenizer.decode(generated)[len(emitted):]
                if tail:
                    yield sse({"token": tail, "gpu": 0})

            yield sse({"done": True})

        return StreamingResponse(
            stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    @modal.fastapi_endpoint(method="GET", docs=True)
    def health(self):
        """Report the model tag, GPU type and whether a model object is loaded."""
        return {
            "status": "ok",
            "model": MODEL_TAG,
            "gpu": GPU_TYPE,
            "ready": self.model is not None,
        }
|
||||
Loading…
Reference in New Issue
Block a user