This commit is contained in:
Santhosh Kumar Ravindran 2026-03-18 18:16:34 -07:00 committed by GitHub
commit d85c01fa64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 94 additions and 36 deletions

View File

@ -61,12 +61,12 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
# Load the model state
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
model_data = torch.load(model_path, map_location=device)
model_data = torch.load(model_path, map_location=device, weights_only=True)
# Load the optimizer state if requested
optimizer_data = None
if load_optimizer:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
optimizer_data = torch.load(optimizer_path, map_location=device)
optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True)
# Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "r", encoding="utf-8") as f:
@ -190,5 +190,5 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
log0(f"Optimizer checkpoint not found: {optimizer_path}")
return None
log0(f"Loading optimizer state from {optimizer_path}")
optimizer_data = torch.load(optimizer_path, map_location=device)
optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True)
return optimizer_data

View File

@ -82,7 +82,13 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
"""
Downloads a file from a URL to a local path in the base directory.
Uses a lock file to prevent concurrent downloads among multiple ranks.
Only HTTPS URLs are accepted to prevent unencrypted transfers.
"""
from urllib.parse import urlparse
parsed = urlparse(url)
if parsed.scheme != "https":
raise ValueError(f"Only HTTPS URLs are allowed for downloads, got: {url!r}")
base_dir = get_base_dir()
file_path = os.path.join(base_dir, filename)
lock_path = file_path + ".lock"

View File

@ -10,6 +10,7 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
import os
import argparse
import time
import tempfile
import requests
import pyarrow.parquet as pq
from multiprocessing import Pool
@ -101,21 +102,23 @@ def download_single_file(index):
try:
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Write to temporary file first
temp_path = filepath + f".tmp"
with open(temp_path, 'wb') as f:
# Write to a unique temporary file to avoid races between concurrent workers
with tempfile.NamedTemporaryFile(
dir=os.path.dirname(filepath), delete=False, suffix=".tmp"
) as tmp_f:
temp_path = tmp_f.name
for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
if chunk:
f.write(chunk)
# Move temp file to final location
os.rename(temp_path, filepath)
tmp_f.write(chunk)
# Atomically move temp file to final location
os.replace(temp_path, filepath)
print(f"Successfully downloaded {filename}")
return True
except (requests.RequestException, IOError) as e:
print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
# Clean up any partial files
for path in [filepath + f".tmp", filepath]:
for path in [temp_path if 'temp_path' in dir() else filepath + ".tmp", filepath]:
if os.path.exists(path):
try:
os.remove(path)

View File

@ -144,12 +144,16 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
with caution.
"""
if platform.uname().system != "Darwin":
# These resource limit calls seem to fail on macOS (Darwin), skip?
if maximum_memory_bytes is not None:
import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
system = platform.uname().system
if system == "Darwin":
# RLIMIT_AS is unsupported on macOS; use RLIMIT_DATA instead
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
elif system == "Linux":
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
faulthandler.disable()

View File

@ -449,7 +449,11 @@
} catch (error) {
console.error('Error:', error);
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
const errDiv = document.createElement('div');
errDiv.className = 'error-message';
errDiv.textContent = `Error: ${error.message}`;
assistantContent.innerHTML = '';
assistantContent.appendChild(errDiv);
} finally {
isGenerating = false;
sendButton.disabled = !chatInput.value.trim();

View File

@ -28,19 +28,29 @@ Abuse Prevention:
- Temperature clamped to 0.0-2.0
- Top-k clamped to 0-200 (0 disables top-k filtering, using full vocabulary)
- Max tokens clamped to 1-4096
- Sliding-window rate limiting per client IP (default: 10 req/60s)
Security Environment Variables:
NANOCHAT_RATE_LIMIT Max requests per window per IP (default: 10)
NANOCHAT_RATE_WINDOW Rate limit window in seconds (default: 60)
NANOCHAT_STATS_KEY If set, /health and /stats require this value in X-Stats-Key header
"""
import argparse
import json
import os
import time
import secrets
import torch
import asyncio
import logging
import random
from collections import defaultdict
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
@ -59,6 +69,35 @@ MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
# Per-IP sliding-window rate limiting, implemented with the standard library only.
# Both knobs come from the environment and fall back to 10 requests per 60 seconds.
RATE_LIMIT_REQUESTS = int(os.getenv("NANOCHAT_RATE_LIMIT", "10"))  # max requests
RATE_LIMIT_WINDOW = int(os.getenv("NANOCHAT_RATE_WINDOW", "60"))  # per N seconds
# Client IP -> time.monotonic() stamps of that client's recent requests.
_rate_limit_store: dict[str, list[float]] = defaultdict(list)
# time.monotonic() value of the last full sweep of _rate_limit_store (0.0 = never).
_rate_limit_last_sweep: float = 0.0


async def check_rate_limit(request: Request) -> None:
    """Sliding-window rate limiter keyed on client IP. Raises HTTP 429 when exceeded.

    Each accepted request records a ``time.monotonic()`` stamp under the
    caller's IP. A request is rejected once the caller already has
    ``RATE_LIMIT_REQUESTS`` stamps newer than ``RATE_LIMIT_WINDOW`` seconds.
    Rejected requests are not recorded.

    Args:
        request: Incoming request; ``request.client.host`` is the bucket key.
            Requests without client info share the ``"unknown"`` bucket.

    Raises:
        HTTPException: Status 429 when the caller's window is already full.
    """
    global _rate_limit_last_sweep

    client_ip = request.client.host if request.client else "unknown"
    now = time.monotonic()
    window_start = now - RATE_LIMIT_WINDOW

    # A bucket is otherwise only pruned when its own IP calls again, so clients
    # that never return would stay in the store forever. Evict fully-expired
    # buckets at most once per window to keep memory bounded by the number of
    # recently active clients while keeping the amortised cost negligible.
    if now - _rate_limit_last_sweep >= RATE_LIMIT_WINDOW:
        _rate_limit_last_sweep = now
        idle_ips = [
            ip
            for ip, stamps in _rate_limit_store.items()
            # Stamps are appended in increasing order, so the last one is the newest.
            if not stamps or stamps[-1] <= window_start
        ]
        for ip in idle_ips:
            del _rate_limit_store[ip]

    # Purge this client's timestamps that fell outside the current window.
    recent = [t for t in _rate_limit_store[client_ip] if t > window_start]
    _rate_limit_store[client_ip] = recent

    if len(recent) >= RATE_LIMIT_REQUESTS:
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded: max {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW}s."
        )
    recent.append(now)
# Opt-in API key guard for /health and /stats: export NANOCHAT_STATS_KEY to turn it on.
# When the variable is unset the key is "" and check_stats_key lets every request through.
_STATS_KEY = os.getenv("NANOCHAT_STATS_KEY", "")
# The key is read from the X-Stats-Key request header.
# NOTE(review): auto_error=False should hand a missing header to the dependency as None
# instead of rejecting the request up front — confirm against the FastAPI docs.
_stats_key_header = APIKeyHeader(auto_error=False, name="X-Stats-Key")
async def check_stats_key(key: Optional[str] = Depends(_stats_key_header)) -> None:
    """If NANOCHAT_STATS_KEY is set, require it on /health and /stats requests.

    Args:
        key: Value of the ``X-Stats-Key`` header, or ``None`` when absent.

    Raises:
        HTTPException: Status 403 when a key is configured and the supplied
            value is missing or does not match.
    """
    if not _STATS_KEY:
        # Guard disabled: no key configured.
        return

    # compare_digest() accepts str only when both are pure ASCII and raises
    # TypeError otherwise. That would turn a non-ASCII header (or configured
    # key) into an HTTP 500 instead of a 403. Comparing bytes avoids that
    # while keeping the constant-time comparison.
    supplied = (key or "").encode("utf-8")
    expected = _STATS_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Forbidden: invalid or missing X-Stats-Key header.")
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
@ -186,7 +225,7 @@ def validate_chat_request(request: ChatRequest):
if message.role not in ["user", "assistant"]:
raise HTTPException(
status_code=400,
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
detail=f"Message {i} has invalid role '{message.role}'. Must be 'user' or 'assistant'."
)
# Validate temperature
@ -227,9 +266,9 @@ app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=False, # credentials require explicit origin allowlist, not wildcard
allow_methods=["GET", "POST"],
allow_headers=["Content-Type", "X-Stats-Key"],
)
@app.get("/")
@ -303,16 +342,18 @@ async def generate_stream(
yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
async def chat_completions(request: Request, chat_request: ChatRequest, _: None = Depends(check_rate_limit)):
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
# Basic validation to prevent abuse
validate_chat_request(request)
validate_chat_request(chat_request)
# Log incoming conversation to console
# Log incoming conversation to console (truncated for privacy)
logger.info("="*20)
for i, message in enumerate(request.messages):
logger.info(f"[{message.role.upper()}]: {message.content}")
for i, message in enumerate(chat_request.messages):
preview = message.content[:120] + ("" if len(message.content) > 120 else "")
logger.debug(f"[{message.role.upper()}]: {preview}")
logger.info(f"Received request with {len(chat_request.messages)} message(s)")
logger.info("-"*20)
# Acquire a worker from the pool (will wait if all are busy)
@ -328,7 +369,7 @@ async def chat_completions(request: ChatRequest):
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
conversation_tokens = [bos]
for message in request.messages:
for message in chat_request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
@ -347,9 +388,9 @@ async def chat_completions(request: ChatRequest):
async for chunk in generate_stream(
worker,
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
temperature=chat_request.temperature,
max_new_tokens=chat_request.max_tokens,
top_k=chat_request.top_k
):
# Accumulate response for logging
chunk_data = json.loads(chunk.replace("data: ", "").strip())
@ -357,9 +398,9 @@ async def chat_completions(request: ChatRequest):
response_tokens.append(chunk_data["token"])
yield chunk
finally:
# Log the assistant response to console
# Log response length only (not content) to avoid storing user data in logs
full_response = "".join(response_tokens)
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {len(full_response)} chars generated")
logger.info("="*20)
# Release worker back to pool after streaming is done
await worker_pool.release_worker(worker)
@ -374,8 +415,8 @@ async def chat_completions(request: ChatRequest):
raise e
@app.get("/health")
async def health():
"""Health check endpoint."""
async def health(_: None = Depends(check_stats_key)):
"""Health check endpoint. Protected by X-Stats-Key header when NANOCHAT_STATS_KEY is set."""
worker_pool = getattr(app.state, 'worker_pool', None)
return {
"status": "ok",
@ -385,8 +426,8 @@ async def health():
}
@app.get("/stats")
async def stats():
"""Get worker pool statistics."""
async def stats(_: None = Depends(check_stats_key)):
"""Get worker pool statistics. Protected by X-Stats-Key header when NANOCHAT_STATS_KEY is set."""
worker_pool = app.state.worker_pool
return {
"total_workers": len(worker_pool.workers),