diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f71524e..407ad15 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -61,12 +61,12 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): # Load the model state model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") - model_data = torch.load(model_path, map_location=device) + model_data = torch.load(model_path, map_location=device, weights_only=True) # Load the optimizer state if requested optimizer_data = None if load_optimizer: optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") - optimizer_data = torch.load(optimizer_path, map_location=device) + optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True) # Load the metadata meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") with open(meta_path, "r", encoding="utf-8") as f: @@ -190,5 +190,5 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None): log0(f"Optimizer checkpoint not found: {optimizer_path}") return None log0(f"Loading optimizer state from {optimizer_path}") - optimizer_data = torch.load(optimizer_path, map_location=device) + optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True) return optimizer_data diff --git a/nanochat/common.py b/nanochat/common.py index bd14fd2..0ab1a27 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -82,7 +82,13 @@ def download_file_with_lock(url, filename, postprocess_fn=None): """ Downloads a file from a URL to a local path in the base directory. Uses a lock file to prevent concurrent downloads among multiple ranks. + Only HTTPS URLs are accepted to prevent unencrypted transfers. 
""" + from urllib.parse import urlparse + parsed = urlparse(url) + if parsed.scheme != "https": + raise ValueError(f"Only HTTPS URLs are allowed for downloads, got: {url!r}") + base_dir = get_base_dir() file_path = os.path.join(base_dir, filename) lock_path = file_path + ".lock" diff --git a/nanochat/dataset.py b/nanochat/dataset.py index fffe722..034a448 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -10,6 +10,7 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`. import os import argparse import time +import tempfile import requests import pyarrow.parquet as pq from multiprocessing import Pool @@ -101,21 +102,23 @@ def download_single_file(index): try: response = requests.get(url, stream=True, timeout=30) response.raise_for_status() - # Write to temporary file first - temp_path = filepath + f".tmp" - with open(temp_path, 'wb') as f: + # Write to a unique temporary file to avoid races between concurrent workers + with tempfile.NamedTemporaryFile( + dir=os.path.dirname(filepath), delete=False, suffix=".tmp" + ) as tmp_f: + temp_path = tmp_f.name for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks if chunk: - f.write(chunk) - # Move temp file to final location - os.rename(temp_path, filepath) + tmp_f.write(chunk) + # Atomically move temp file to final location + os.replace(temp_path, filepath) print(f"Successfully downloaded {filename}") return True except (requests.RequestException, IOError) as e: print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}") # Clean up any partial files - for path in [filepath + f".tmp", filepath]: + for path in [temp_path if 'temp_path' in dir() else filepath + ".tmp", filepath]: if os.path.exists(path): try: os.remove(path) diff --git a/nanochat/execution.py b/nanochat/execution.py index 6f50c74..521c845 100644 --- a/nanochat/execution.py +++ b/nanochat/execution.py @@ -144,12 +144,16 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): 
with caution. """ - if platform.uname().system != "Darwin": - # These resource limit calls seem to fail on macOS (Darwin), skip? + if maximum_memory_bytes is not None: import resource - resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) - resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) - resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + system = platform.uname().system + if system == "Darwin": + # RLIMIT_AS is unsupported on macOS; use RLIMIT_DATA instead + resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) + elif system == "Linux": + resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) faulthandler.disable() diff --git a/nanochat/ui.html b/nanochat/ui.html index b85532a..536b99b 100644 --- a/nanochat/ui.html +++ b/nanochat/ui.html @@ -449,7 +449,11 @@ } catch (error) { console.error('Error:', error); - assistantContent.innerHTML = `
<div class="error-message">Error: ${error.message}</div>
`; + const errDiv = document.createElement('div'); + errDiv.className = 'error-message'; + errDiv.textContent = `Error: ${error.message}`; + assistantContent.innerHTML = ''; + assistantContent.appendChild(errDiv); } finally { isGenerating = false; sendButton.disabled = !chatInput.value.trim(); diff --git a/scripts/chat_web.py b/scripts/chat_web.py index ffaf7da..7b79cd5 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -28,19 +28,29 @@ Abuse Prevention: - Temperature clamped to 0.0-2.0 - Top-k clamped to 0-200 (0 disables top-k filtering, using full vocabulary) - Max tokens clamped to 1-4096 + - Sliding-window rate limiting per client IP (default: 10 req/60s) + +Security Environment Variables: + NANOCHAT_RATE_LIMIT Max requests per window per IP (default: 10) + NANOCHAT_RATE_WINDOW Rate limit window in seconds (default: 60) + NANOCHAT_STATS_KEY If set, /health and /stats require this value in X-Stats-Key header """ import argparse import json import os +import time +import secrets import torch import asyncio import logging import random +from collections import defaultdict from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request, Depends from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse +from fastapi.security import APIKeyHeader from pydantic import BaseModel from typing import List, Optional, AsyncGenerator from dataclasses import dataclass @@ -59,6 +69,35 @@ MAX_TOP_K = 200 MIN_MAX_TOKENS = 1 MAX_MAX_TOKENS = 4096 +# Rate limiting: sliding-window per IP, no extra dependencies required +RATE_LIMIT_REQUESTS = int(os.environ.get("NANOCHAT_RATE_LIMIT", "10")) # max requests +RATE_LIMIT_WINDOW = int(os.environ.get("NANOCHAT_RATE_WINDOW", "60")) # per N seconds +_rate_limit_store: dict[str, list[float]] = defaultdict(list) + +async def check_rate_limit(request: Request) -> None: + """Sliding-window 
rate limiter keyed on client IP. Raises HTTP 429 when exceeded.""" + client_ip = request.client.host if request.client else "unknown" + now = time.monotonic() + window_start = now - RATE_LIMIT_WINDOW + timestamps = _rate_limit_store[client_ip] + # Purge timestamps outside the current window + _rate_limit_store[client_ip] = [t for t in timestamps if t > window_start] + if len(_rate_limit_store[client_ip]) >= RATE_LIMIT_REQUESTS: + raise HTTPException( + status_code=429, + detail=f"Rate limit exceeded: max {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW}s." + ) + _rate_limit_store[client_ip].append(now) + +# Optional API key guard for /health and /stats (set NANOCHAT_STATS_KEY env var to enable) +_STATS_KEY = os.environ.get("NANOCHAT_STATS_KEY", "") +_stats_key_header = APIKeyHeader(name="X-Stats-Key", auto_error=False) + +async def check_stats_key(key: Optional[str] = Depends(_stats_key_header)) -> None: + """If NANOCHAT_STATS_KEY is set, require it on /health and /stats requests.""" + if _STATS_KEY and not secrets.compare_digest(key or "", _STATS_KEY): + raise HTTPException(status_code=403, detail="Forbidden: invalid or missing X-Stats-Key header.") + parser = argparse.ArgumentParser(description='NanoChat Web Server') parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)') parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl") @@ -186,7 +225,7 @@ def validate_chat_request(request: ChatRequest): if message.role not in ["user", "assistant"]: raise HTTPException( status_code=400, - detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'" + detail=f"Message {i} has invalid role '{message.role}'. Must be 'user' or 'assistant'." 
) # Validate temperature @@ -227,9 +266,9 @@ app = FastAPI(lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_credentials=False, # credentials require explicit origin allowlist, not wildcard + allow_methods=["GET", "POST"], + allow_headers=["Content-Type", "X-Stats-Key"], ) @app.get("/") @@ -303,16 +342,18 @@ async def generate_stream( yield f"data: {json.dumps({'done': True})}\n\n" @app.post("/chat/completions") -async def chat_completions(request: ChatRequest): +async def chat_completions(request: Request, chat_request: ChatRequest, _: None = Depends(check_rate_limit)): """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU.""" # Basic validation to prevent abuse - validate_chat_request(request) + validate_chat_request(chat_request) - # Log incoming conversation to console + # Log incoming conversation to console (truncated for privacy) logger.info("="*20) - for i, message in enumerate(request.messages): - logger.info(f"[{message.role.upper()}]: {message.content}") + for i, message in enumerate(chat_request.messages): + preview = message.content[:120] + ("…" if len(message.content) > 120 else "") + logger.debug(f"[{message.role.upper()}]: {preview}") + logger.info(f"Received request with {len(chat_request.messages)} message(s)") logger.info("-"*20) # Acquire a worker from the pool (will wait if all are busy) @@ -328,7 +369,7 @@ async def chat_completions(request: ChatRequest): assistant_end = worker.tokenizer.encode_special("<|assistant_end|>") conversation_tokens = [bos] - for message in request.messages: + for message in chat_request.messages: if message.role == "user": conversation_tokens.append(user_start) conversation_tokens.extend(worker.tokenizer.encode(message.content)) @@ -347,9 +388,9 @@ async def chat_completions(request: ChatRequest): async for chunk in generate_stream( worker, conversation_tokens, - 
temperature=request.temperature, - max_new_tokens=request.max_tokens, - top_k=request.top_k + temperature=chat_request.temperature, + max_new_tokens=chat_request.max_tokens, + top_k=chat_request.top_k ): # Accumulate response for logging chunk_data = json.loads(chunk.replace("data: ", "").strip()) @@ -357,9 +398,9 @@ async def chat_completions(request: ChatRequest): response_tokens.append(chunk_data["token"]) yield chunk finally: - # Log the assistant response to console + # Log response length only (not content) to avoid storing user data in logs full_response = "".join(response_tokens) - logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}") + logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {len(full_response)} chars generated") logger.info("="*20) # Release worker back to pool after streaming is done await worker_pool.release_worker(worker) @@ -374,8 +415,8 @@ async def chat_completions(request: ChatRequest): raise e @app.get("/health") -async def health(): - """Health check endpoint.""" +async def health(_: None = Depends(check_stats_key)): + """Health check endpoint. Protected by X-Stats-Key header when NANOCHAT_STATS_KEY is set.""" worker_pool = getattr(app.state, 'worker_pool', None) return { "status": "ok", @@ -385,8 +426,8 @@ async def health(): } @app.get("/stats") -async def stats(): - """Get worker pool statistics.""" +async def stats(_: None = Depends(check_stats_key)): + """Get worker pool statistics. Protected by X-Stats-Key header when NANOCHAT_STATS_KEY is set.""" worker_pool = app.state.worker_pool return { "total_workers": len(worker_pool.workers),