mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 05:35:19 +00:00
security: fix unsafe deserialization, XSS, HTTPS enforcement, and temp file race
Five targeted security fixes — all non-breaking, no behaviour change on the happy path.

H-1 (High) — nanochat/checkpoint_manager.py
Add weights_only=True to all three torch.load() calls. torch.load() uses pickle by default; loading a malicious .pt file from an untrusted source allows arbitrary code execution. weights_only=True restricts deserialization to tensors and primitives, blocking this attack surface.
Refs: https://pytorch.org/docs/stable/generated/torch.load.html

H-3 (High) — nanochat/ui.html
Replace innerHTML injection with createElement + textContent for error display. error.message was interpolated directly into innerHTML, creating an XSS sink: a crafted server error response could inject and execute arbitrary JavaScript. textContent escapes all HTML entities, closing the injection path.

L-1 (Low) — scripts/chat_web.py
Fix misleading role validation error message. The error string claimed 'system' was a valid role, but the guard only accepts 'user' and 'assistant'. Corrected to reflect the actual allowed values.

M-3 (Medium) — nanochat/common.py
Reject non-HTTPS URLs in download_file_with_lock(). urlopen() follows redirects including HTTPS->HTTP downgrades, enabling MITM attacks on downloaded model/tokenizer files. Added an explicit scheme check that raises ValueError for any non-HTTPS URL before the request is made.

L-3 (Low) — nanochat/dataset.py
Replace predictable .tmp suffix with tempfile.NamedTemporaryFile. The previous filepath + '.tmp' naming caused a TOCTOU race when multiple worker processes downloaded the same shard concurrently, and is vulnerable to symlink attacks on shared filesystems. NamedTemporaryFile generates a unique path; os.replace() provides an atomic rename on POSIX.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
parent
1076f97059
commit
3fa394c93f
|
|
@ -61,12 +61,12 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
|
|||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
|
||||
# Load the model state
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
model_data = torch.load(model_path, map_location=device)
|
||||
model_data = torch.load(model_path, map_location=device, weights_only=True)
|
||||
# Load the optimizer state if requested
|
||||
optimizer_data = None
|
||||
if load_optimizer:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True)
|
||||
# Load the metadata
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
with open(meta_path, "r", encoding="utf-8") as f:
|
||||
|
|
@ -190,5 +190,5 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
|
|||
log0(f"Optimizer checkpoint not found: {optimizer_path}")
|
||||
return None
|
||||
log0(f"Loading optimizer state from {optimizer_path}")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device, weights_only=True)
|
||||
return optimizer_data
|
||||
|
|
|
|||
|
|
@ -82,7 +82,13 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
|
|||
"""
|
||||
Downloads a file from a URL to a local path in the base directory.
|
||||
Uses a lock file to prevent concurrent downloads among multiple ranks.
|
||||
Only HTTPS URLs are accepted to prevent unencrypted transfers.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme != "https":
|
||||
raise ValueError(f"Only HTTPS URLs are allowed for downloads, got: {url!r}")
|
||||
|
||||
base_dir = get_base_dir()
|
||||
file_path = os.path.join(base_dir, filename)
|
||||
lock_path = file_path + ".lock"
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
|
|||
import os
|
||||
import argparse
|
||||
import time
|
||||
import tempfile
|
||||
import requests
|
||||
import pyarrow.parquet as pq
|
||||
from multiprocessing import Pool
|
||||
|
|
@ -101,21 +102,23 @@ def download_single_file(index):
|
|||
try:
|
||||
response = requests.get(url, stream=True, timeout=30)
|
||||
response.raise_for_status()
|
||||
# Write to temporary file first
|
||||
temp_path = filepath + f".tmp"
|
||||
with open(temp_path, 'wb') as f:
|
||||
# Write to a unique temporary file to avoid races between concurrent workers
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(filepath), delete=False, suffix=".tmp"
|
||||
) as tmp_f:
|
||||
temp_path = tmp_f.name
|
||||
for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
# Move temp file to final location
|
||||
os.rename(temp_path, filepath)
|
||||
tmp_f.write(chunk)
|
||||
# Atomically move temp file to final location
|
||||
os.replace(temp_path, filepath)
|
||||
print(f"Successfully downloaded {filename}")
|
||||
return True
|
||||
|
||||
except (requests.RequestException, IOError) as e:
|
||||
print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
|
||||
# Clean up any partial files
|
||||
for path in [filepath + f".tmp", filepath]:
|
||||
for path in [temp_path if 'temp_path' in dir() else filepath + ".tmp", filepath]:
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
|
|
|
|||
|
|
@ -449,7 +449,11 @@
|
|||
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
|
||||
const errDiv = document.createElement('div');
|
||||
errDiv.className = 'error-message';
|
||||
errDiv.textContent = `Error: ${error.message}`;
|
||||
assistantContent.innerHTML = '';
|
||||
assistantContent.appendChild(errDiv);
|
||||
} finally {
|
||||
isGenerating = false;
|
||||
sendButton.disabled = !chatInput.value.trim();
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ def validate_chat_request(request: ChatRequest):
|
|||
if message.role not in ["user", "assistant"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
|
||||
detail=f"Message {i} has invalid role '{message.role}'. Must be 'user' or 'assistant'."
|
||||
)
|
||||
|
||||
# Validate temperature
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user