mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 21:25:21 +00:00
Merge ac5927e158 into 5019accc5b
This commit is contained in:
commit
d85c01fa64
|
|
@ -61,12 +61,12 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
|
|||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
|
||||
# Load the model state
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
model_data = torch.load(model_path, map_location=device)
|
||||
model_data = torch.load(model_path, map_location=device, weights_only=True)
|
||||
# Load the optimizer state if requested
|
||||
optimizer_data = None
|
||||
if load_optimizer:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True)
|
||||
# Load the metadata
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
with open(meta_path, "r", encoding="utf-8") as f:
|
||||
|
|
@ -190,5 +190,5 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
|
|||
log0(f"Optimizer checkpoint not found: {optimizer_path}")
|
||||
return None
|
||||
log0(f"Loading optimizer state from {optimizer_path}")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True)
|
||||
return optimizer_data
|
||||
|
|
|
|||
|
|
@ -82,7 +82,13 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
|
|||
"""
|
||||
Downloads a file from a URL to a local path in the base directory.
|
||||
Uses a lock file to prevent concurrent downloads among multiple ranks.
|
||||
Only HTTPS URLs are accepted to prevent unencrypted transfers.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme != "https":
|
||||
raise ValueError(f"Only HTTPS URLs are allowed for downloads, got: {url!r}")
|
||||
|
||||
base_dir = get_base_dir()
|
||||
file_path = os.path.join(base_dir, filename)
|
||||
lock_path = file_path + ".lock"
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
|
|||
import os
|
||||
import argparse
|
||||
import time
|
||||
import tempfile
|
||||
import requests
|
||||
import pyarrow.parquet as pq
|
||||
from multiprocessing import Pool
|
||||
|
|
@ -101,21 +102,23 @@ def download_single_file(index):
|
|||
try:
|
||||
response = requests.get(url, stream=True, timeout=30)
|
||||
response.raise_for_status()
|
||||
# Write to temporary file first
|
||||
temp_path = filepath + f".tmp"
|
||||
with open(temp_path, 'wb') as f:
|
||||
# Write to a unique temporary file to avoid races between concurrent workers
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(filepath), delete=False, suffix=".tmp"
|
||||
) as tmp_f:
|
||||
temp_path = tmp_f.name
|
||||
for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
# Move temp file to final location
|
||||
os.rename(temp_path, filepath)
|
||||
tmp_f.write(chunk)
|
||||
# Atomically move temp file to final location
|
||||
os.replace(temp_path, filepath)
|
||||
print(f"Successfully downloaded {filename}")
|
||||
return True
|
||||
|
||||
except (requests.RequestException, IOError) as e:
|
||||
print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
|
||||
# Clean up any partial files
|
||||
for path in [filepath + f".tmp", filepath]:
|
||||
for path in [temp_path if 'temp_path' in dir() else filepath + ".tmp", filepath]:
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
|
|
|
|||
|
|
@ -144,12 +144,16 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
|||
with caution.
|
||||
"""
|
||||
|
||||
if platform.uname().system != "Darwin":
|
||||
# These resource limit calls seem to fail on macOS (Darwin), skip?
|
||||
if maximum_memory_bytes is not None:
|
||||
import resource
|
||||
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
system = platform.uname().system
|
||||
if system == "Darwin":
|
||||
# RLIMIT_AS is unsupported on macOS; use RLIMIT_DATA instead
|
||||
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
elif system == "Linux":
|
||||
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
|
||||
faulthandler.disable()
|
||||
|
||||
|
|
|
|||
|
|
@ -449,7 +449,11 @@
|
|||
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
|
||||
const errDiv = document.createElement('div');
|
||||
errDiv.className = 'error-message';
|
||||
errDiv.textContent = `Error: ${error.message}`;
|
||||
assistantContent.innerHTML = '';
|
||||
assistantContent.appendChild(errDiv);
|
||||
} finally {
|
||||
isGenerating = false;
|
||||
sendButton.disabled = !chatInput.value.trim();
|
||||
|
|
|
|||
|
|
@ -28,19 +28,29 @@ Abuse Prevention:
|
|||
- Temperature clamped to 0.0-2.0
|
||||
- Top-k clamped to 0-200 (0 disables top-k filtering, using full vocabulary)
|
||||
- Max tokens clamped to 1-4096
|
||||
- Sliding-window rate limiting per client IP (default: 10 req/60s)
|
||||
|
||||
Security Environment Variables:
|
||||
NANOCHAT_RATE_LIMIT Max requests per window per IP (default: 10)
|
||||
NANOCHAT_RATE_WINDOW Rate limit window in seconds (default: 60)
|
||||
NANOCHAT_STATS_KEY If set, /health and /stats require this value in X-Stats-Key header
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import secrets
|
||||
import torch
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -59,6 +69,35 @@ MAX_TOP_K = 200
|
|||
MIN_MAX_TOKENS = 1
|
||||
MAX_MAX_TOKENS = 4096
|
||||
|
||||
# Rate limiting: sliding-window per IP, no extra dependencies required
|
||||
RATE_LIMIT_REQUESTS = int(os.environ.get("NANOCHAT_RATE_LIMIT", "10")) # max requests
|
||||
RATE_LIMIT_WINDOW = int(os.environ.get("NANOCHAT_RATE_WINDOW", "60")) # per N seconds
|
||||
_rate_limit_store: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
async def check_rate_limit(request: Request) -> None:
|
||||
"""Sliding-window rate limiter keyed on client IP. Raises HTTP 429 when exceeded."""
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
now = time.monotonic()
|
||||
window_start = now - RATE_LIMIT_WINDOW
|
||||
timestamps = _rate_limit_store[client_ip]
|
||||
# Purge timestamps outside the current window
|
||||
_rate_limit_store[client_ip] = [t for t in timestamps if t > window_start]
|
||||
if len(_rate_limit_store[client_ip]) >= RATE_LIMIT_REQUESTS:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Rate limit exceeded: max {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW}s."
|
||||
)
|
||||
_rate_limit_store[client_ip].append(now)
|
||||
|
||||
# Optional API key guard for /health and /stats (set NANOCHAT_STATS_KEY env var to enable)
|
||||
_STATS_KEY = os.environ.get("NANOCHAT_STATS_KEY", "")
|
||||
_stats_key_header = APIKeyHeader(name="X-Stats-Key", auto_error=False)
|
||||
|
||||
async def check_stats_key(key: Optional[str] = Depends(_stats_key_header)) -> None:
|
||||
"""If NANOCHAT_STATS_KEY is set, require it on /health and /stats requests."""
|
||||
if _STATS_KEY and not secrets.compare_digest(key or "", _STATS_KEY):
|
||||
raise HTTPException(status_code=403, detail="Forbidden: invalid or missing X-Stats-Key header.")
|
||||
|
||||
parser = argparse.ArgumentParser(description='NanoChat Web Server')
|
||||
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
|
||||
|
|
@ -186,7 +225,7 @@ def validate_chat_request(request: ChatRequest):
|
|||
if message.role not in ["user", "assistant"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
|
||||
detail=f"Message {i} has invalid role '{message.role}'. Must be 'user' or 'assistant'."
|
||||
)
|
||||
|
||||
# Validate temperature
|
||||
|
|
@ -227,9 +266,9 @@ app = FastAPI(lifespan=lifespan)
|
|||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_credentials=False, # credentials require explicit origin allowlist, not wildcard
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type", "X-Stats-Key"],
|
||||
)
|
||||
|
||||
@app.get("/")
|
||||
|
|
@ -303,16 +342,18 @@ async def generate_stream(
|
|||
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||
|
||||
@app.post("/chat/completions")
|
||||
async def chat_completions(request: ChatRequest):
|
||||
async def chat_completions(request: Request, chat_request: ChatRequest, _: None = Depends(check_rate_limit)):
|
||||
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
|
||||
|
||||
# Basic validation to prevent abuse
|
||||
validate_chat_request(request)
|
||||
validate_chat_request(chat_request)
|
||||
|
||||
# Log incoming conversation to console
|
||||
# Log incoming conversation to console (truncated for privacy)
|
||||
logger.info("="*20)
|
||||
for i, message in enumerate(request.messages):
|
||||
logger.info(f"[{message.role.upper()}]: {message.content}")
|
||||
for i, message in enumerate(chat_request.messages):
|
||||
preview = message.content[:120] + ("…" if len(message.content) > 120 else "")
|
||||
logger.debug(f"[{message.role.upper()}]: {preview}")
|
||||
logger.info(f"Received request with {len(chat_request.messages)} message(s)")
|
||||
logger.info("-"*20)
|
||||
|
||||
# Acquire a worker from the pool (will wait if all are busy)
|
||||
|
|
@ -328,7 +369,7 @@ async def chat_completions(request: ChatRequest):
|
|||
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
|
||||
|
||||
conversation_tokens = [bos]
|
||||
for message in request.messages:
|
||||
for message in chat_request.messages:
|
||||
if message.role == "user":
|
||||
conversation_tokens.append(user_start)
|
||||
conversation_tokens.extend(worker.tokenizer.encode(message.content))
|
||||
|
|
@ -347,9 +388,9 @@ async def chat_completions(request: ChatRequest):
|
|||
async for chunk in generate_stream(
|
||||
worker,
|
||||
conversation_tokens,
|
||||
temperature=request.temperature,
|
||||
max_new_tokens=request.max_tokens,
|
||||
top_k=request.top_k
|
||||
temperature=chat_request.temperature,
|
||||
max_new_tokens=chat_request.max_tokens,
|
||||
top_k=chat_request.top_k
|
||||
):
|
||||
# Accumulate response for logging
|
||||
chunk_data = json.loads(chunk.replace("data: ", "").strip())
|
||||
|
|
@ -357,9 +398,9 @@ async def chat_completions(request: ChatRequest):
|
|||
response_tokens.append(chunk_data["token"])
|
||||
yield chunk
|
||||
finally:
|
||||
# Log the assistant response to console
|
||||
# Log response length only (not content) to avoid storing user data in logs
|
||||
full_response = "".join(response_tokens)
|
||||
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
|
||||
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {len(full_response)} chars generated")
|
||||
logger.info("="*20)
|
||||
# Release worker back to pool after streaming is done
|
||||
await worker_pool.release_worker(worker)
|
||||
|
|
@ -374,8 +415,8 @@ async def chat_completions(request: ChatRequest):
|
|||
raise e
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
async def health(_: None = Depends(check_stats_key)):
|
||||
"""Health check endpoint. Protected by X-Stats-Key header when NANOCHAT_STATS_KEY is set."""
|
||||
worker_pool = getattr(app.state, 'worker_pool', None)
|
||||
return {
|
||||
"status": "ok",
|
||||
|
|
@ -385,8 +426,8 @@ async def health():
|
|||
}
|
||||
|
||||
@app.get("/stats")
|
||||
async def stats():
|
||||
"""Get worker pool statistics."""
|
||||
async def stats(_: None = Depends(check_stats_key)):
|
||||
"""Get worker pool statistics. Protected by X-Stats-Key header when NANOCHAT_STATS_KEY is set."""
|
||||
worker_pool = app.state.worker_pool
|
||||
return {
|
||||
"total_workers": len(worker_pool.workers),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user