From fe427a03c3f06abc587fcf1d5e42bcd5596ab3cd Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Sun, 8 Feb 2026 20:04:11 +0100 Subject: [PATCH 1/2] fix: harden deserialization, path traversal, log redaction, and size limits Address all 5 security findings from issue #55: 1. [CRITICAL] Restricted pickle deserialization via _RestrictedUnpickler (blocks arbitrary code execution from crafted index files) 2. [HIGH] Path traversal validation on index_name at all entry points 3. [MEDIUM] Staging boundary check before file deletion in crash handler 4. [MEDIUM] API key redaction in LEANN debug/error log output 5. [LOW] Bounded zlib decompression (100MB), download size (500MB), and tarball extraction limits (100MB/file, 500MB total) Includes 25 new tests covering all fixes. Co-Authored-By: Claude Opus 4.6 --- paperpipe/cli/ask.py | 4 +- paperpipe/leann.py | 33 ++++++-- paperpipe/paper.py | 35 ++++++++- paperpipe/paperqa.py | 108 ++++++++++++++++++++++--- paperpipe/paperqa_mcp_server.py | 83 ++++++++++++++++++-- tests/test_figure_extraction.py | 2 + tests/test_leann.py | 36 ++++++++- tests/test_paper.py | 45 +++++++++++ tests/test_paperqa.py | 135 ++++++++++++++++++++++++++++++++ 9 files changed, 455 insertions(+), 26 deletions(-) diff --git a/paperpipe/cli/ask.py b/paperpipe/cli/ask.py index 25c14fb..380d738 100644 --- a/paperpipe/cli/ask.py +++ b/paperpipe/cli/ask.py @@ -729,7 +729,9 @@ def ask( # Only remove files from paperpipe's managed staging directory. # Never delete from a user-provided paper directory. managed_staging_dir = (config.PAPER_DB / ".pqa_papers").expanduser() - if paper_dir.resolve() == managed_staging_dir.resolve(): + if paper_dir.resolve() == managed_staging_dir.resolve() and f.resolve().is_relative_to( + managed_staging_dir.resolve() + ): try: f.unlink() echo_warning(f"Removed '{crashing_doc}' from PaperQA2 staging to prevent re-indexing.") diff --git a/paperpipe/leann.py b/paperpipe/leann.py index f9bdef3..beccd29 100644 --- a/paperpipe/leann.py +++ b/paperpipe/leann.py @@ -25,6 +25,7 @@ default_leann_llm_provider, ) from .output import debug, echo_error, echo_progress, echo_warning +from .paperqa import _validate_index_name # ----------------------------------------------------------------------------- # Manifest types and I/O for incremental indexing @@ -69,7 +70,7 @@ class IndexDelta: def _leann_manifest_path(index_name: str) -> Path: """Path to paperpipe's incremental indexing manifest.""" - return config.PAPER_DB / ".leann" / "indexes" / index_name / "paperpipe_manifest.json" + return config.PAPER_DB / ".leann" / "indexes" / _validate_index_name(index_name) / "paperpipe_manifest.json" def _load_leann_manifest(index_name: str) -> Optional[LeannManifest]: @@ -250,7 +251,7 @@ def _leann_incremental_update( except ImportError as e: raise IncrementalUpdateError(f"LEANN Python API not available: {e}") from e - index_dir = config.PAPER_DB / ".leann" / "indexes" / index_name + index_dir = config.PAPER_DB / ".leann" / "indexes" / _validate_index_name(index_name) index_path = index_dir / "documents.leann" index_file = index_dir / f"{index_path.stem}.index" @@ -436,6 +437,26 @@ def _leann_incremental_update( return added, delta.unchanged_count, errors +_REDACT_FLAGS = {"--api-key", "--embedding-api-key"} + + +def _redact_cmd(cmd: list[str]) -> str: + """Return a shell-safe string representation of cmd with API keys redacted.""" + redacted = list(cmd) + i = 0 + while i < len(redacted): + if redacted[i] in _REDACT_FLAGS and i + 1 < len(redacted): + redacted[i + 1] = "***" + i += 2 + elif any(redacted[i].startswith(f"{flag}=") for flag in _REDACT_FLAGS): + flag = redacted[i].split("=", 1)[0] + redacted[i] = f"{flag}=***" + i += 1 + else: + i += 1 + return shlex.join(redacted) + + def _extract_arg_value(args: list[str], flag: str) -> Optional[str]: """Extract the value for a CLI flag from args list.""" for i, arg in enumerate(args): @@ -452,7 +473,7 @@ def _extract_arg_value(args: list[str], flag: str) -> Optional[str]: def _leann_index_meta_path(index_name: str) -> Path: - return config.PAPER_DB / ".leann" / "indexes" / index_name / "documents.leann.meta.json" + return config.PAPER_DB / ".leann" / "indexes" / _validate_index_name(index_name) / "documents.leann.meta.json" def _load_leann_backend_name(index_name: str) -> Optional[str]: @@ -598,11 +619,11 @@ def _leann_build_index( embedding_mode_for_meta = embedding_mode_override or default_leann_embedding_mode() cmd.extend(extra_args) - debug("Running LEANN: %s", shlex.join(cmd)) + debug("Running LEANN: %s", _redact_cmd(cmd)) proc = subprocess.run(cmd, cwd=config.PAPER_DB) if proc.returncode != 0: echo_error(f"LEANN command failed (exit code {proc.returncode})") - echo_error(f"Command: {shlex.join(cmd)}") + echo_error(f"Command: {_redact_cmd(cmd)}") raise SystemExit(proc.returncode) # Write metadata on success @@ -704,7 +725,7 @@ def _ask_leann( cmd.extend(["--thinking-budget", thinking_budget]) cmd.extend(extra_args) - debug("Running LEANN: %s", shlex.join(cmd)) + debug("Running LEANN: %s", _redact_cmd(cmd)) if interactive: proc = subprocess.run(cmd, cwd=config.PAPER_DB) diff --git a/paperpipe/paper.py b/paperpipe/paper.py index cfd681b..e38dcd8 100644 --- a/paperpipe/paper.py +++ b/paperpipe/paper.py @@ -41,6 +41,10 @@ from .output import debug, echo_error, echo_progress, echo_success, echo_warning from .search import _maybe_update_search_index +_MAX_DOWNLOAD_SIZE = 500 * 1024 * 1024 # 500 MB +_MAX_TAR_MEMBER_SIZE = 100 * 1024 * 1024 # 100 MB +_MAX_TAR_TOTAL_SIZE = 500 * 1024 * 1024 # 500 MB + def fetch_arxiv_metadata(arxiv_id: str) -> dict: """Fetch paper metadata from arXiv API.""" @@ -136,6 +140,15 @@ def download_source(arxiv_id: str, paper_dir: Path, *, extract_figures: bool = F echo_warning(f"Could not download source for {arxiv_id}: {e}") return None + # Check download size + content_length = response.headers.get("Content-Length") + if content_length and int(content_length) > _MAX_DOWNLOAD_SIZE: + echo_warning(f"Source archive for {arxiv_id} too large ({int(content_length)} bytes). Skipping.") + return None + if len(response.content) > _MAX_DOWNLOAD_SIZE: + echo_warning(f"Source archive for {arxiv_id} too large ({len(response.content)} bytes). Skipping.") + return None + # Save and extract tarball with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as f: f.write(response.content) @@ -145,14 +158,21 @@ def download_source(arxiv_id: str, paper_dir: Path, *, extract_figures: bool = F try: # Try to open as tar (most common) with tarfile.open(tar_path) as tar: - tex_members = [m for m in tar.getmembers() if m.isfile() and m.name.endswith(".tex")] + tex_members = [ + m for m in tar.getmembers() if m.isfile() and m.name.endswith(".tex") and m.size <= _MAX_TAR_MEMBER_SIZE + ] if tex_members: tex_by_name: dict[str, str] = {} + total_extracted = 0 for member in tex_members: + if total_extracted + member.size > _MAX_TAR_TOTAL_SIZE: + break extracted = tar.extractfile(member) if not extracted: continue - tex_by_name[member.name] = extracted.read().decode("utf-8", errors="replace") + content = extracted.read() + total_extracted += len(content) + tex_by_name[member.name] = content.decode("utf-8", errors="replace") preferred_names = ("main.tex", "paper.tex") preferred = [n for n in tex_by_name if Path(n).name in preferred_names] @@ -288,6 +308,7 @@ def _extract_figures_from_latex(tex_content: str, tar: tarfile.TarFile, paper_di figures_dir = paper_dir / "figures" extracted_count = 0 + total_extracted_size = 0 for ref in matches: # LaTeX allows omitting file extension, so we need to try with and without @@ -311,10 +332,16 @@ def _extract_figures_from_latex(tex_content: str, tar: tarfile.TarFile, paper_di for member in tar.getmembers(): if not member.isfile(): continue + if member.size > _MAX_TAR_MEMBER_SIZE: + continue member_name = member.name.replace("\\", "/") # Check if member path ends with candidate or matches exactly if member_name == candidate_path or member_name.endswith("/" + candidate_path): + if total_extracted_size + member.size > _MAX_TAR_TOTAL_SIZE: + echo_warning(" Tar extraction total size limit reached") + return extracted_count + # Extract the file if not figures_dir.exists(): try: @@ -331,7 +358,9 @@ def _extract_figures_from_latex(tex_content: str, tar: tarfile.TarFile, paper_di dest_path = figures_dir / dest_name # Write file - dest_path.write_bytes(file_obj.read()) + data = file_obj.read() + total_extracted_size += len(data) + dest_path.write_bytes(data) extracted_count += 1 found = True break diff --git a/paperpipe/paperqa.py b/paperpipe/paperqa.py index c428767..540a92a 100644 --- a/paperpipe/paperqa.py +++ b/paperpipe/paperqa.py @@ -2,6 +2,7 @@ from __future__ import annotations +import io import os import pickle import re @@ -23,6 +24,79 @@ def _pillow_available() -> bool: return importlib.util.find_spec("PIL") is not None +_SAFE_PICKLE_BUILTINS = frozenset({"dict", "str", "list", "tuple", "set", "frozenset", "int", "float", "bool", "bytes"}) + + +class _RestrictedUnpickler(pickle.Unpickler): + """Unpickler that only allows basic Python builtins (no arbitrary code execution).""" + + def find_class(self, module: str, name: str) -> Any: + if module == "builtins" and name in _SAFE_PICKLE_BUILTINS: + return ( + getattr(__builtins__ if isinstance(__builtins__, dict) else __builtins__, name, None) + or { + "dict": dict, + "str": str, + "list": list, + "tuple": tuple, + "set": set, + "frozenset": frozenset, + "int": int, + "float": float, + "bool": bool, + "bytes": bytes, + }[name] + ) + raise pickle.UnpicklingError(f"Restricted: {module}.{name}") + + +def _safe_unpickle(data: bytes) -> Any: + """Deserialize pickle data using a restricted unpickler (blocks arbitrary code execution).""" + return _RestrictedUnpickler(io.BytesIO(data)).load() + + +_MAX_ZLIB_DECOMPRESS_SIZE = 100 * 1024 * 1024 # 100 MB + + +def _safe_zlib_decompress(data: bytes, *, max_size: int = _MAX_ZLIB_DECOMPRESS_SIZE) -> bytes: + """Decompress zlib data with a size limit to prevent decompression bombs.""" + dobj = zlib.decompressobj() + chunks: list[bytes] = [] + total = 0 + # Feed data in 64 KB blocks to limit peak memory before aborting + block_size = 65536 + for i in range(0, len(data), block_size): + chunk = dobj.decompress(data[i : i + block_size]) + total += len(chunk) + if total > max_size: + raise ValueError(f"Decompressed data exceeds {max_size} byte limit") + chunks.append(chunk) + # Flush remaining data + remaining = dobj.flush() + total += len(remaining) + if total > max_size: + raise ValueError(f"Decompressed data exceeds {max_size} byte limit") + chunks.append(remaining) + return b"".join(chunks) + + +def _validate_index_name(index_name: str) -> str: + """Validate and return a safe index name (rejects path traversal attempts).""" + name = (index_name or "").strip() + if not name: + raise ValueError("index_name must be non-empty") + if "\x00" in name: + raise ValueError("index_name must not contain null bytes") + if "/" in name or "\\" in name: + raise ValueError("index_name must not contain path separators") + if ".." in name: + raise ValueError("index_name must not contain '..' components") + # Also reject pure dots + if all(c == "." for c in name): + raise ValueError("index_name must not be only dots") + return name + + def _refresh_pqa_pdf_staging_dir(*, staging_dir: Path, exclude_names: Optional[set[str]] = None) -> int: """ Create/update a flat directory containing only PDFs (one per paper) for PaperQA2 indexing. @@ -138,17 +212,29 @@ def _paperqa_find_crashing_file(*, paper_directory: Path, crashing_doc: str) -> if not doc: return None + resolved_parent = paper_directory.resolve() + + def _within_parent(p: Path) -> Optional[Path]: + """Return resolved path if it's within paper_directory, else None.""" + try: + rp = p.resolve() + if rp.is_relative_to(resolved_parent) and rp.exists(): + return rp + except (OSError, ValueError): + pass + return None + doc_path = Path(doc) if doc_path.is_absolute(): - return doc_path if doc_path.exists() else None + return _within_parent(doc_path) if ".." in doc_path.parts: doc_path = Path(doc_path.name) # Try the path as-is (relative to the paper directory). - candidate = paper_directory / doc_path - if candidate.exists(): - return candidate + result = _within_parent(paper_directory / doc_path) + if result is not None: + return result # Try matching by file name/stem (common when pqa prints just "foo.pdf" or "foo"). name = doc_path.name @@ -159,7 +245,9 @@ def _paperqa_find_crashing_file(*, paper_directory: Path, crashing_doc: str) -> try: for f in paper_directory.iterdir(): if f.name == name or f.stem == expected_stem: - return f + checked = _within_parent(f) + if checked is not None: + return checked except OSError: pass @@ -167,7 +255,9 @@ def _paperqa_find_crashing_file(*, paper_directory: Path, crashing_doc: str) -> try: for f in paper_directory.rglob(name): if f.name == name: - return f + checked = _within_parent(f) + if checked is not None: + return checked except OSError: pass @@ -175,13 +265,13 @@ def _paperqa_find_crashing_file(*, paper_directory: Path, crashing_doc: str) -> def _paperqa_index_files_path(*, index_directory: Path, index_name: str) -> Path: - return Path(index_directory) / index_name / "files.zip" + return Path(index_directory) / _validate_index_name(index_name) / "files.zip" def _paperqa_load_index_files_map(path: Path) -> Optional[dict[str, str]]: try: - raw = zlib.decompress(path.read_bytes()) - obj = pickle.loads(raw) + raw = _safe_zlib_decompress(path.read_bytes()) + obj = _safe_unpickle(raw) except Exception: return None if not isinstance(obj, dict): diff --git a/paperpipe/paperqa_mcp_server.py b/paperpipe/paperqa_mcp_server.py index 14c63c2..174faf0 100644 --- a/paperpipe/paperqa_mcp_server.py +++ b/paperpipe/paperqa_mcp_server.py @@ -30,9 +30,11 @@ from __future__ import annotations +import io import json import logging import os +import pickle import sys import zlib from pathlib import Path @@ -122,17 +124,85 @@ def _default_index_name(embedding_model: str) -> str: return _index_name_for_embedding(embedding_model) +_SAFE_PICKLE_BUILTINS = frozenset({"dict", "str", "list", "tuple", "set", "frozenset", "int", "float", "bool", "bytes"}) + + +class _RestrictedUnpickler(pickle.Unpickler): + """Unpickler that only allows basic Python builtins (no arbitrary code execution).""" + + def find_class(self, module: str, name: str) -> Any: + if module == "builtins" and name in _SAFE_PICKLE_BUILTINS: + return ( + getattr(__builtins__ if isinstance(__builtins__, dict) else __builtins__, name, None) + or { + "dict": dict, + "str": str, + "list": list, + "tuple": tuple, + "set": set, + "frozenset": frozenset, + "int": int, + "float": float, + "bool": bool, + "bytes": bytes, + }[name] + ) + raise pickle.UnpicklingError(f"Restricted: {module}.{name}") + + +def _safe_unpickle(data: bytes) -> Any: + """Deserialize pickle data using a restricted unpickler (blocks arbitrary code execution).""" + return _RestrictedUnpickler(io.BytesIO(data)).load() + + +_MAX_ZLIB_DECOMPRESS_SIZE = 100 * 1024 * 1024 # 100 MB + + +def _safe_zlib_decompress(data: bytes, *, max_size: int = _MAX_ZLIB_DECOMPRESS_SIZE) -> bytes: + """Decompress zlib data with a size limit to prevent decompression bombs.""" + dobj = zlib.decompressobj() + chunks: list[bytes] = [] + total = 0 + block_size = 65536 + for i in range(0, len(data), block_size): + chunk = dobj.decompress(data[i : i + block_size]) + total += len(chunk) + if total > max_size: + raise ValueError(f"Decompressed data exceeds {max_size} byte limit") + chunks.append(chunk) + remaining = dobj.flush() + total += len(remaining) + if total > max_size: + raise ValueError(f"Decompressed data exceeds {max_size} byte limit") + chunks.append(remaining) + return b"".join(chunks) + + +def _validate_index_name(index_name: str) -> str: + """Validate and return a safe index name (rejects path traversal attempts).""" + name = (index_name or "").strip() + if not name: + raise ValueError("index_name must be non-empty") + if "\x00" in name: + raise ValueError("index_name must not contain null bytes") + if "/" in name or "\\" in name: + raise ValueError("index_name must not contain path separators") + if ".." in name: + raise ValueError("index_name must not contain '..' components") + if all(c == "." for c in name): + raise ValueError("index_name must not be only dots") + return name + + def _index_meta_exists(index_root: Path, index_name: str) -> bool: - return (index_root / index_name / "index" / "meta.json").exists() + return (index_root / _validate_index_name(index_name) / "index" / "meta.json").exists() def _load_files_zip_map(files_zip: Path) -> dict[str, str] | None: """Load PaperQA2's index file map (zlib-compressed pickle).""" try: - import pickle - - raw = zlib.decompress(files_zip.read_bytes()) - obj = pickle.loads(raw) # noqa: S301 - PaperQA2 format + raw = _safe_zlib_decompress(files_zip.read_bytes()) + obj = _safe_unpickle(raw) except Exception: return None if not isinstance(obj, dict): @@ -150,7 +220,7 @@ def _load_embedding_from_metadata(index_root: Path, index_name: str) -> str | No Returns None if file doesn't exist or is invalid. """ - meta_path = index_root / index_name / "paperpipe_meta.json" + meta_path = index_root / _validate_index_name(index_name) / "paperpipe_meta.json" if not meta_path.exists(): return None try: @@ -171,6 +241,7 @@ def _write_index_metadata(index_root: Path, index_name: str, embedding_model: st Creates paperpipe_meta.json with embedding model, timestamp, and version. Callers should wrap in try/except to ensure index operations succeed regardless of metadata write failures. """ + _validate_index_name(index_name) try: from importlib.metadata import PackageNotFoundError from importlib.metadata import version as get_version diff --git a/tests/test_figure_extraction.py b/tests/test_figure_extraction.py index 4541223..b26d1e9 100644 --- a/tests/test_figure_extraction.py +++ b/tests/test_figure_extraction.py @@ -381,6 +381,7 @@ class MockResponse: def __init__(self): self.content = tar_bytes self.status_code = 200 + self.headers: dict[str, str] = {} def raise_for_status(self): pass @@ -420,6 +421,7 @@ class MockResponse: def __init__(self): self.content = tar_bytes self.status_code = 200 + self.headers: dict[str, str] = {} def raise_for_status(self): pass diff --git a/tests/test_leann.py b/tests/test_leann.py index 1750d6f..27b221f 100644 --- a/tests/test_leann.py +++ b/tests/test_leann.py @@ -14,7 +14,7 @@ import paperpipe import paperpipe.config as config import paperpipe.paperqa as paperqa -from paperpipe.leann import FileEntry, LeannManifest +from paperpipe.leann import FileEntry, LeannManifest, _redact_cmd # Import the CLI module explicitly (avoid resolving to the package's cli function). cli_mod = import_module("paperpipe.cli") @@ -722,3 +722,37 @@ def update_index(self, path): assert added == 0 assert unchanged == 1 assert errors == 0 + + +class TestRedactCmd: + """Tests for _redact_cmd (Fix #4: API key leakage in logs).""" + + def test_redacts_api_key(self): + cmd = ["leann", "ask", "papers", "query", "--api-key", "sk-secret-123"] + result = _redact_cmd(cmd) + assert "sk-secret-123" not in result + assert "***" in result + + def test_redacts_embedding_api_key(self): + cmd = ["leann", "build", "papers", "--embedding-api-key", "sk-embed-456"] + result = _redact_cmd(cmd) + assert "sk-embed-456" not in result + assert "***" in result + + def test_no_false_redaction(self): + cmd = ["leann", "ask", "papers", "query", "--model", "gpt-4o"] + result = _redact_cmd(cmd) + assert "gpt-4o" in result + assert "***" not in result + + def test_key_at_end(self): + """Flag at end of list with no value should not crash.""" + cmd = ["leann", "ask", "--api-key"] + result = _redact_cmd(cmd) + assert "--api-key" in result + + def test_redacts_equals_form(self): + cmd = ["leann", "ask", "--api-key=sk-secret"] + result = _redact_cmd(cmd) + assert "sk-secret" not in result + assert "--api-key=***" in result diff --git a/tests/test_paper.py b/tests/test_paper.py index 2d50ccf..aef1f80 100644 --- a/tests/test_paper.py +++ b/tests/test_paper.py @@ -1007,3 +1007,48 @@ def test_falls_back_to_largest_file(self, tmp_path, monkeypatch): # No \begin{document}, so returns None assert result is None + + +class TestDownloadSourceSizeLimits: + """Tests for download size limits (Fix #5b: bounded download size).""" + + def test_oversized_content_length_rejected(self, tmp_path, monkeypatch): + """Content-Length header exceeding limit should cause download to be skipped.""" + from unittest.mock import MagicMock + + import requests + + mock_response = MagicMock() + mock_response.headers = {"Content-Length": str(600 * 1024 * 1024)} # 600 MB + mock_response.content = b"small" + mock_response.raise_for_status = MagicMock() + + monkeypatch.setattr(requests, "get", lambda url, timeout: mock_response) + + paper_dir = tmp_path / "test-paper" + paper_dir.mkdir() + + result = paperpipe.download_source("1234.56789", paper_dir) + assert result is None + + def test_oversized_content_rejected(self, tmp_path, monkeypatch): + """Actual content exceeding limit should cause download to be skipped.""" + from unittest.mock import MagicMock + + import requests + + # Simulate content that exceeds the limit (we'll monkeypatch the constant for test speed) + monkeypatch.setattr(paper_mod, "_MAX_DOWNLOAD_SIZE", 100) + + mock_response = MagicMock() + mock_response.headers = {} # No Content-Length + mock_response.content = b"x" * 200 # 200 bytes > 100 byte limit + mock_response.raise_for_status = MagicMock() + + monkeypatch.setattr(requests, "get", lambda url, timeout: mock_response) + + paper_dir = tmp_path / "test-paper" + paper_dir.mkdir() + + result = paperpipe.download_source("1234.56789", paper_dir) + assert result is None diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index c0b5aa3..9001e8a 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -1,5 +1,8 @@ from __future__ import annotations +import pickle +import zlib + import pytest import paperpipe @@ -84,3 +87,135 @@ def test_strips_whitespace(self): def test_single_line(self): assert paperqa._first_line("just one line") == "just one line" + + +class TestSafeUnpickle: + """Tests for _safe_unpickle (Fix #1: restricted pickle deserialization).""" + + def test_allows_valid_dict(self): + data = pickle.dumps({"key": "value", "num": 42}) + result = paperqa._safe_unpickle(data) + assert result == {"key": "value", "num": 42} + + def test_allows_empty_dict(self): + data = pickle.dumps({}) + result = paperqa._safe_unpickle(data) + assert result == {} + + def test_rejects_os_system(self): + # Build a payload using reduce that references os.system + import os as _os + + class _OsSystem: + def __reduce__(self): + return (_os.system, ("echo pwned",)) + + payload = pickle.dumps(_OsSystem()) + with pytest.raises(pickle.UnpicklingError, match="Restricted"): + paperqa._safe_unpickle(payload) + + def test_rejects_subprocess(self): + # Build a pickle payload that tries to import subprocess.call + class Evil: + def __reduce__(self): + import subprocess + + return (subprocess.call, (["echo", "pwned"],)) + + payload = pickle.dumps(Evil()) + with pytest.raises(pickle.UnpicklingError, match="Restricted"): + paperqa._safe_unpickle(payload) + + def test_integration_with_load_index_files_map(self, tmp_path): + """Integration: _paperqa_load_index_files_map uses safe unpickle.""" + mapping = {"file1.pdf": "ok", "file2.pdf": "ERROR"} + compressed = zlib.compress(pickle.dumps(mapping)) + files_zip = tmp_path / "test_index" / "files.zip" + files_zip.parent.mkdir(parents=True) + files_zip.write_bytes(compressed) + result = paperqa._paperqa_load_index_files_map(files_zip) + assert result == mapping + + +class TestValidateIndexName: + """Tests for _validate_index_name (Fix #2: path traversal prevention).""" + + def test_valid_name(self): + assert paperqa._validate_index_name("paperpipe_text-embedding-3-small") == "paperpipe_text-embedding-3-small" + + def test_rejects_empty(self): + with pytest.raises(ValueError, match="non-empty"): + paperqa._validate_index_name("") + + def test_rejects_dot_dot(self): + with pytest.raises(ValueError, match="\\.\\."): + paperqa._validate_index_name("..secret") + + def test_rejects_slash(self): + with pytest.raises(ValueError, match="path separator"): + paperqa._validate_index_name("foo/bar") + + def test_rejects_backslash(self): + with pytest.raises(ValueError, match="path separator"): + paperqa._validate_index_name("foo\\bar") + + def test_rejects_null_byte(self): + with pytest.raises(ValueError, match="null"): + paperqa._validate_index_name("foo\x00bar") + + +class TestFindCrashingFileSecurity: + """Tests for _paperqa_find_crashing_file boundary checks (Fix #3).""" + + def test_absolute_outside_rejected(self, tmp_path): + paper_dir = tmp_path / "papers" + paper_dir.mkdir() + outside = tmp_path / "secret.txt" + outside.write_text("secret") + result = paperqa._paperqa_find_crashing_file(paper_directory=paper_dir, crashing_doc=str(outside)) + assert result is None + + def test_absolute_inside_allowed(self, tmp_path): + paper_dir = tmp_path / "papers" + paper_dir.mkdir() + inside = paper_dir / "paper.pdf" + inside.write_text("pdf") + result = paperqa._paperqa_find_crashing_file(paper_directory=paper_dir, crashing_doc=str(inside)) + assert result is not None + + def test_dotdot_traversal_rejected(self, tmp_path): + paper_dir = tmp_path / "papers" + paper_dir.mkdir() + outside = tmp_path / "secret.txt" + outside.write_text("secret") + result = paperqa._paperqa_find_crashing_file(paper_directory=paper_dir, crashing_doc="../secret.txt") + assert result is None + + def test_relative_inside_allowed(self, tmp_path): + paper_dir = tmp_path / "papers" + paper_dir.mkdir() + (paper_dir / "test.pdf").write_text("pdf") + result = paperqa._paperqa_find_crashing_file(paper_directory=paper_dir, crashing_doc="test.pdf") + assert result is not None + + +class TestSafeZlibDecompress: + """Tests for _safe_zlib_decompress (Fix #5a: bounded decompression).""" + + def test_normal_data(self): + original = b"hello world" * 100 + compressed = zlib.compress(original) + result = paperqa._safe_zlib_decompress(compressed) + assert result == original + + def test_oversized_rejection(self): + # Create data larger than a 1KB limit + original = b"x" * 2048 + compressed = zlib.compress(original) + with pytest.raises(ValueError, match="exceeds"): + paperqa._safe_zlib_decompress(compressed, max_size=1024) + + def test_empty_data(self): + compressed = zlib.compress(b"") + result = paperqa._safe_zlib_decompress(compressed) + assert result == b"" From d141ad89edb9a1570c91fd9ece96e3ccf7bd802b Mon Sep 17 00:00:00 2001 From: Matthias Humt Date: Sun, 8 Feb 2026 20:12:54 +0100 Subject: [PATCH 2/2] fix: address code review findings on security PR - Validate index_name at entry of get_pqa_index_status before path construction - Use decompressobj max_length to truly bound memory during decompression - Guard int(Content-Length) against malformed headers (ValueError) - Use validated index_name in _write_index_metadata path construction Co-Authored-By: Claude Opus 4.6 --- paperpipe/paper.py | 8 ++++++-- paperpipe/paperqa.py | 17 +++++++++++------ paperpipe/paperqa_mcp_server.py | 19 ++++++++++++------- tests/test_paper.py | 23 +++++++++++++++++++++++ 4 files changed, 52 insertions(+), 15 deletions(-) diff --git a/paperpipe/paper.py b/paperpipe/paper.py index e38dcd8..bf7bc23 100644 --- a/paperpipe/paper.py +++ b/paperpipe/paper.py @@ -142,8 +142,12 @@ def download_source(arxiv_id: str, paper_dir: Path, *, extract_figures: bool = F # Check download size content_length = response.headers.get("Content-Length") - if content_length and int(content_length) > _MAX_DOWNLOAD_SIZE: - echo_warning(f"Source archive for {arxiv_id} too large ({int(content_length)} bytes). Skipping.") + try: + content_length_int = int(content_length) if content_length else None + except (ValueError, TypeError): + content_length_int = None + if content_length_int is not None and content_length_int > _MAX_DOWNLOAD_SIZE: + echo_warning(f"Source archive for {arxiv_id} too large ({content_length_int} bytes). Skipping.") return None if len(response.content) > _MAX_DOWNLOAD_SIZE: echo_warning(f"Source archive for {arxiv_id} too large ({len(response.content)} bytes). Skipping.") diff --git a/paperpipe/paperqa.py b/paperpipe/paperqa.py index 540a92a..ed4b61a 100644 --- a/paperpipe/paperqa.py +++ b/paperpipe/paperqa.py @@ -63,14 +63,19 @@ def _safe_zlib_decompress(data: bytes, *, max_size: int = _MAX_ZLIB_DECOMPRESS_S dobj = zlib.decompressobj() chunks: list[bytes] = [] total = 0 - # Feed data in 64 KB blocks to limit peak memory before aborting + # Feed compressed data in 64 KB blocks; cap decompressed output via max_length block_size = 65536 for i in range(0, len(data), block_size): - chunk = dobj.decompress(data[i : i + block_size]) - total += len(chunk) - if total > max_size: - raise ValueError(f"Decompressed data exceeds {max_size} byte limit") - chunks.append(chunk) + dobj.decompress(b"", 0) # no-op; ensures unconsumed_tail is drained below + buf = data[i : i + block_size] + while buf: + remaining_budget = max_size - total + 1 # +1 to detect overflow + chunk = dobj.decompress(buf, max_length=remaining_budget) + total += len(chunk) + if total > max_size: + raise ValueError(f"Decompressed data exceeds {max_size} byte limit") + chunks.append(chunk) + buf = dobj.unconsumed_tail # Flush remaining data remaining = dobj.flush() total += len(remaining) diff --git a/paperpipe/paperqa_mcp_server.py b/paperpipe/paperqa_mcp_server.py index 174faf0..e4cb5ab 100644 --- a/paperpipe/paperqa_mcp_server.py +++ b/paperpipe/paperqa_mcp_server.py @@ -165,11 +165,16 @@ def _safe_zlib_decompress(data: bytes, *, max_size: int = _MAX_ZLIB_DECOMPRESS_S total = 0 block_size = 65536 for i in range(0, len(data), block_size): - chunk = dobj.decompress(data[i : i + block_size]) - total += len(chunk) - if total > max_size: - raise ValueError(f"Decompressed data exceeds {max_size} byte limit") - chunks.append(chunk) + dobj.decompress(b"", 0) # no-op; ensures unconsumed_tail is drained below + buf = data[i : i + block_size] + while buf: + remaining_budget = max_size - total + 1 # +1 to detect overflow + chunk = dobj.decompress(buf, max_length=remaining_budget) + total += len(chunk) + if total > max_size: + raise ValueError(f"Decompressed data exceeds {max_size} byte limit") + chunks.append(chunk) + buf = dobj.unconsumed_tail remaining = dobj.flush() total += len(remaining) if total > max_size: @@ -241,7 +246,7 @@ def _write_index_metadata(index_root: Path, index_name: str, embedding_model: st Creates paperpipe_meta.json with embedding model, timestamp, and version. Callers should wrap in try/except to ensure index operations succeed regardless of metadata write failures. """ - _validate_index_name(index_name) + index_name = _validate_index_name(index_name) try: from importlib.metadata import PackageNotFoundError from importlib.metadata import version as get_version @@ -483,7 +488,7 @@ async def get_pqa_index_status( """Return basic status info about the PaperQA2 index (no heavy imports).""" embedding_model = (embedding_model or "").strip() or _default_embedding_model() index_root = Path(index_dir).expanduser() if index_dir else _default_index_root() - index_name = (index_name or "").strip() or _default_index_name(embedding_model) + index_name = _validate_index_name((index_name or "").strip() or _default_index_name(embedding_model)) index_path = index_root / index_name files_zip = index_path / "files.zip" mapping = _load_files_zip_map(files_zip) if files_zip.exists() else None diff --git a/tests/test_paper.py b/tests/test_paper.py index aef1f80..a408ba2 100644 --- a/tests/test_paper.py +++ b/tests/test_paper.py @@ -1031,6 +1031,29 @@ def test_oversized_content_length_rejected(self, tmp_path, monkeypatch): result = paperpipe.download_source("1234.56789", paper_dir) assert result is None + def test_malformed_content_length_ignored(self, tmp_path, monkeypatch): + """Malformed Content-Length header should not crash; download proceeds normally.""" + from unittest.mock import MagicMock + + import requests + + # Monkeypatch to small limit so the body size check triggers + monkeypatch.setattr(paper_mod, "_MAX_DOWNLOAD_SIZE", 100) + + mock_response = MagicMock() + mock_response.headers = {"Content-Length": "invalid"} + mock_response.content = b"x" * 200 # exceeds monkeypatched limit + mock_response.raise_for_status = MagicMock() + + monkeypatch.setattr(requests, "get", lambda url, timeout: mock_response) + + paper_dir = tmp_path / "test-paper" + paper_dir.mkdir() + + # Should not raise ValueError; falls through to body size check + result = paperpipe.download_source("1234.56789", paper_dir) + assert result is None + def test_oversized_content_rejected(self, tmp_path, monkeypatch): """Actual content exceeding limit should cause download to be skipped.""" from unittest.mock import MagicMock