security: fix unsafe deserialization, XSS, HTTPS enforcement, and temp file race

Five targeted security fixes — all non-breaking, no behaviour change on the happy path.

H-1 (High) — nanochat/checkpoint_manager.py
  Add weights_only=True to all three torch.load() calls.
  torch.load() uses pickle by default; loading a malicious .pt file from an
  untrusted source allows arbitrary code execution. weights_only=True restricts
  deserialization to tensors and primitives, blocking this attack surface.
  Refs: https://pytorch.org/docs/stable/generated/torch.load.html

H-3 (High) — nanochat/ui.html
  Replace innerHTML injection with createElement + textContent for error display.
  error.message was interpolated directly into innerHTML, creating an XSS sink:
  a crafted server error response could inject and execute arbitrary JavaScript.
  textContent escapes all HTML entities, closing the injection path.

L-1 (Low) — scripts/chat_web.py
  Fix misleading role validation error message.
  The error string claimed 'system' was a valid role, but the guard only accepts
  'user' and 'assistant'. Corrected to reflect the actual allowed values.

M-3 (Medium) — nanochat/common.py
  Reject non-HTTPS URLs in download_file_with_lock().
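  urlopen() follows redirects including HTTPS->HTTP downgrades, enabling MITM
  attacks on downloaded model/tokenizer files. Added an explicit scheme check
  that raises ValueError for any non-HTTPS URL before the request is made.
  Note: the check validates only the URL passed to the function; on its own it
  does not stop urlopen() from following a later redirect to an HTTP location.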

L-3 (Low) — nanochat/dataset.py
  Replace predictable .tmp suffix with tempfile.NamedTemporaryFile.
  The previous filepath + '.tmp' naming caused a TOCTOU race when multiple
  worker processes downloaded the same shard concurrently, and is vulnerable
  to symlink attacks on shared filesystems. NamedTemporaryFile generates a
  unique path in the destination directory, so the final os.replace() stays on
  one filesystem and provides an atomic rename on POSIX.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
santhoshravindran7 2026-03-08 23:12:50 -07:00
parent 1076f97059
commit 3fa394c93f
5 changed files with 25 additions and 12 deletions

View File

@ -61,12 +61,12 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
# Load the model state
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
model_data = torch.load(model_path, map_location=device)
model_data = torch.load(model_path, map_location=device, weights_only=True)
# Load the optimizer state if requested
optimizer_data = None
if load_optimizer:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
optimizer_data = torch.load(optimizer_path, map_location=device)
optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True)
# Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "r", encoding="utf-8") as f:
@ -190,5 +190,5 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
log0(f"Optimizer checkpoint not found: {optimizer_path}")
return None
log0(f"Loading optimizer state from {optimizer_path}")
optimizer_data = torch.load(optimizer_path, map_location=device)
optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True)
return optimizer_data

View File

@ -82,7 +82,13 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
"""
Downloads a file from a URL to a local path in the base directory.
Uses a lock file to prevent concurrent downloads among multiple ranks.
Only HTTPS URLs are accepted to prevent unencrypted transfers.
"""
from urllib.parse import urlparse
parsed = urlparse(url)
if parsed.scheme != "https":
raise ValueError(f"Only HTTPS URLs are allowed for downloads, got: {url!r}")
base_dir = get_base_dir()
file_path = os.path.join(base_dir, filename)
lock_path = file_path + ".lock"

View File

@ -10,6 +10,7 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
import os
import argparse
import time
import tempfile
import requests
import pyarrow.parquet as pq
from multiprocessing import Pool
@ -101,21 +102,23 @@ def download_single_file(index):
try:
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Write to temporary file first
temp_path = filepath + f".tmp"
with open(temp_path, 'wb') as f:
# Write to a unique temporary file to avoid races between concurrent workers
with tempfile.NamedTemporaryFile(
dir=os.path.dirname(filepath), delete=False, suffix=".tmp"
) as tmp_f:
temp_path = tmp_f.name
for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
if chunk:
f.write(chunk)
# Move temp file to final location
os.rename(temp_path, filepath)
tmp_f.write(chunk)
# Atomically move temp file to final location
os.replace(temp_path, filepath)
print(f"Successfully downloaded {filename}")
return True
except (requests.RequestException, IOError) as e:
print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
# Clean up any partial files
for path in [filepath + f".tmp", filepath]:
for path in [temp_path if 'temp_path' in dir() else filepath + ".tmp", filepath]:
if os.path.exists(path):
try:
os.remove(path)

View File

@ -449,7 +449,11 @@
} catch (error) {
console.error('Error:', error);
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
const errDiv = document.createElement('div');
errDiv.className = 'error-message';
errDiv.textContent = `Error: ${error.message}`;
assistantContent.innerHTML = '';
assistantContent.appendChild(errDiv);
} finally {
isGenerating = false;
sendButton.disabled = !chatInput.value.trim();

View File

@ -186,7 +186,7 @@ def validate_chat_request(request: ChatRequest):
if message.role not in ["user", "assistant"]:
raise HTTPException(
status_code=400,
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
detail=f"Message {i} has invalid role '{message.role}'. Must be 'user' or 'assistant'."
)
# Validate temperature