From 3fa394c93f4aa5b53e9bc1de7b00cea55f518f12 Mon Sep 17 00:00:00 2001 From: santhoshravindran7 Date: Sun, 8 Mar 2026 23:12:50 -0700 Subject: [PATCH] security: fix unsafe deserialization, XSS, HTTPS enforcement, and temp file race MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Five targeted security fixes — all non-breaking, no behaviour change on the happy path. H-1 (High) — nanochat/checkpoint_manager.py Add weights_only=True to all three torch.load() calls. torch.load() uses pickle by default; loading a malicious .pt file from an untrusted source allows arbitrary code execution. weights_only=True restricts deserialization to tensors and primitives, blocking this attack surface. Refs: https://pytorch.org/docs/stable/generated/torch.load.html H-3 (High) — nanochat/ui.html Replace innerHTML injection with createElement + textContent for error display. error.message was interpolated directly into innerHTML, creating an XSS sink: a crafted server error response could inject and execute arbitrary JavaScript. textContent assigns the value as plain text that is never parsed as HTML, closing the injection path. L-1 (Low) — scripts/chat_web.py Fix misleading role validation error message. The error string claimed 'system' was a valid role, but the guard only accepts 'user' and 'assistant'. Corrected to reflect the actual allowed values. M-3 (Medium) — nanochat/common.py Reject non-HTTPS URLs in download_file_with_lock(). urlopen() follows redirects including HTTPS->HTTP downgrades, enabling MITM attacks on downloaded model/tokenizer files. Added an explicit scheme check that raises ValueError for any non-HTTPS URL before the request is made. Note that this check validates only the initial URL; it does not by itself stop urlopen() from following a later redirect to an HTTP location. L-3 (Low) — nanochat/dataset.py Replace predictable .tmp suffix with tempfile.NamedTemporaryFile. The previous filepath + '.tmp' naming caused a TOCTOU race when multiple worker processes downloaded the same shard concurrently, and was vulnerable to symlink attacks on shared filesystems. 
NamedTemporaryFile generates a unique path; os.replace() provides an atomic rename on POSIX. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- nanochat/checkpoint_manager.py | 6 +++--- nanochat/common.py | 6 ++++++ nanochat/dataset.py | 17 ++++++++++------- nanochat/ui.html | 6 +++++- scripts/chat_web.py | 2 +- 5 files changed, 25 insertions(+), 12 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f71524e..407ad15 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -61,12 +61,12 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): # Load the model state model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") - model_data = torch.load(model_path, map_location=device) + model_data = torch.load(model_path, map_location=device, weights_only=True) # Load the optimizer state if requested optimizer_data = None if load_optimizer: optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") - optimizer_data = torch.load(optimizer_path, map_location=device) + optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True) # Load the metadata meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") with open(meta_path, "r", encoding="utf-8") as f: @@ -190,5 +190,5 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None): log0(f"Optimizer checkpoint not found: {optimizer_path}") return None log0(f"Loading optimizer state from {optimizer_path}") - optimizer_data = torch.load(optimizer_path, map_location=device) + optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True) return optimizer_data diff --git a/nanochat/common.py b/nanochat/common.py index bd14fd2..0ab1a27 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -82,7 +82,13 @@ def 
download_file_with_lock(url, filename, postprocess_fn=None): """ Downloads a file from a URL to a local path in the base directory. Uses a lock file to prevent concurrent downloads among multiple ranks. + Only HTTPS URLs are accepted to prevent unencrypted transfers. """ + from urllib.parse import urlparse + parsed = urlparse(url) + if parsed.scheme != "https": + raise ValueError(f"Only HTTPS URLs are allowed for downloads, got: {url!r}") + base_dir = get_base_dir() file_path = os.path.join(base_dir, filename) lock_path = file_path + ".lock" diff --git a/nanochat/dataset.py b/nanochat/dataset.py index fffe722..034a448 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -10,6 +10,7 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`. import os import argparse import time +import tempfile import requests import pyarrow.parquet as pq from multiprocessing import Pool @@ -101,21 +102,23 @@ def download_single_file(index): try: response = requests.get(url, stream=True, timeout=30) response.raise_for_status() - # Write to temporary file first - temp_path = filepath + f".tmp" - with open(temp_path, 'wb') as f: + # Write to a unique temporary file to avoid races between concurrent workers + with tempfile.NamedTemporaryFile( + dir=os.path.dirname(filepath), delete=False, suffix=".tmp" + ) as tmp_f: + temp_path = tmp_f.name for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks if chunk: - f.write(chunk) - # Move temp file to final location - os.rename(temp_path, filepath) + tmp_f.write(chunk) + # Atomically move temp file to final location + os.replace(temp_path, filepath) print(f"Successfully downloaded {filename}") return True except (requests.RequestException, IOError) as e: print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}") # Clean up any partial files - for path in [filepath + f".tmp", filepath]: + for path in [temp_path if 'temp_path' in dir() else filepath + ".tmp", filepath]: if 
os.path.exists(path): try: os.remove(path) diff --git a/nanochat/ui.html b/nanochat/ui.html index b85532a..536b99b 100644 --- a/nanochat/ui.html +++ b/nanochat/ui.html @@ -449,7 +449,11 @@ } catch (error) { console.error('Error:', error); - assistantContent.innerHTML = `
<div class="error-message">Error: ${error.message}</div>
`; + const errDiv = document.createElement('div'); + errDiv.className = 'error-message'; + errDiv.textContent = `Error: ${error.message}`; + assistantContent.innerHTML = ''; + assistantContent.appendChild(errDiv); } finally { isGenerating = false; sendButton.disabled = !chatInput.value.trim(); diff --git a/scripts/chat_web.py b/scripts/chat_web.py index ffaf7da..5e9763a 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -186,7 +186,7 @@ def validate_chat_request(request: ChatRequest): if message.role not in ["user", "assistant"]: raise HTTPException( status_code=400, - detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'" + detail=f"Message {i} has invalid role '{message.role}'. Must be 'user' or 'assistant'." ) # Validate temperature