paperpipe/cli/ask.py (4 changes: 3 additions & 1 deletion)
@@ -729,7 +729,9 @@ def ask(
# Only remove files from paperpipe's managed staging directory.
# Never delete from a user-provided paper directory.
managed_staging_dir = (config.PAPER_DB / ".pqa_papers").expanduser()
if paper_dir.resolve() == managed_staging_dir.resolve():
if paper_dir.resolve() == managed_staging_dir.resolve() and f.resolve().is_relative_to(
managed_staging_dir.resolve()
):
try:
f.unlink()
echo_warning(f"Removed '{crashing_doc}' from PaperQA2 staging to prevent re-indexing.")
paperpipe/leann.py (33 changes: 27 additions & 6 deletions)
@@ -25,6 +25,7 @@
default_leann_llm_provider,
)
from .output import debug, echo_error, echo_progress, echo_warning
from .paperqa import _validate_index_name

# -----------------------------------------------------------------------------
# Manifest types and I/O for incremental indexing
@@ -69,7 +70,7 @@ class IndexDelta:

def _leann_manifest_path(index_name: str) -> Path:
"""Path to paperpipe's incremental indexing manifest."""
return config.PAPER_DB / ".leann" / "indexes" / index_name / "paperpipe_manifest.json"
return config.PAPER_DB / ".leann" / "indexes" / _validate_index_name(index_name) / "paperpipe_manifest.json"


def _load_leann_manifest(index_name: str) -> Optional[LeannManifest]:
@@ -250,7 +251,7 @@ def _leann_incremental_update(
except ImportError as e:
raise IncrementalUpdateError(f"LEANN Python API not available: {e}") from e

index_dir = config.PAPER_DB / ".leann" / "indexes" / index_name
index_dir = config.PAPER_DB / ".leann" / "indexes" / _validate_index_name(index_name)
index_path = index_dir / "documents.leann"

index_file = index_dir / f"{index_path.stem}.index"
@@ -436,6 +437,26 @@ def _leann_incremental_update(
return added, delta.unchanged_count, errors


_REDACT_FLAGS = {"--api-key", "--embedding-api-key"}


def _redact_cmd(cmd: list[str]) -> str:
"""Return a shell-safe string representation of cmd with API keys redacted."""
redacted = list(cmd)
i = 0
while i < len(redacted):
if redacted[i] in _REDACT_FLAGS and i + 1 < len(redacted):
redacted[i + 1] = "***"
i += 2
elif any(redacted[i].startswith(f"{flag}=") for flag in _REDACT_FLAGS):
flag = redacted[i].split("=", 1)[0]
redacted[i] = f"{flag}=***"
i += 1
else:
i += 1
return shlex.join(redacted)
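For reference, a quick usage sketch of _redact_cmd covering both argument forms it handles: the two-token "--api-key VALUE" and the single-token "--api-key=VALUE". The import path is an assumption based on the file shown in this PR.

from paperpipe.leann import _redact_cmd  # assumed import path for the module above

cmd = [
    "leann", "build", "documents",
    "--api-key", "sk-live-123",          # two-token form: the value token is replaced
    "--embedding-api-key=sk-embed-456",  # flag=value form: everything after '=' is replaced
    "--backend", "hnsw",                 # unrelated flags pass through untouched
]
print(_redact_cmd(cmd))
# Both secrets are printed as *** while the rest of the command stays intact.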


def _extract_arg_value(args: list[str], flag: str) -> Optional[str]:
"""Extract the value for a CLI flag from args list."""
for i, arg in enumerate(args):
@@ -452,7 +473,7 @@ def _extract_arg_value(args: list[str], flag: str) -> Optional[str]:


def _leann_index_meta_path(index_name: str) -> Path:
return config.PAPER_DB / ".leann" / "indexes" / index_name / "documents.leann.meta.json"
return config.PAPER_DB / ".leann" / "indexes" / _validate_index_name(index_name) / "documents.leann.meta.json"


def _load_leann_backend_name(index_name: str) -> Optional[str]:
@@ -598,11 +619,11 @@ def _leann_build_index(
embedding_mode_for_meta = embedding_mode_override or default_leann_embedding_mode()

cmd.extend(extra_args)
debug("Running LEANN: %s", shlex.join(cmd))
debug("Running LEANN: %s", _redact_cmd(cmd))
proc = subprocess.run(cmd, cwd=config.PAPER_DB)
if proc.returncode != 0:
echo_error(f"LEANN command failed (exit code {proc.returncode})")
echo_error(f"Command: {shlex.join(cmd)}")
echo_error(f"Command: {_redact_cmd(cmd)}")
raise SystemExit(proc.returncode)

# Write metadata on success
@@ -704,7 +725,7 @@ def _ask_leann(
cmd.extend(["--thinking-budget", thinking_budget])

cmd.extend(extra_args)
debug("Running LEANN: %s", shlex.join(cmd))
debug("Running LEANN: %s", _redact_cmd(cmd))

if interactive:
proc = subprocess.run(cmd, cwd=config.PAPER_DB)
paperpipe/paper.py (39 changes: 36 additions & 3 deletions)
@@ -41,6 +41,10 @@
from .output import debug, echo_error, echo_progress, echo_success, echo_warning
from .search import _maybe_update_search_index

_MAX_DOWNLOAD_SIZE = 500 * 1024 * 1024 # 500 MB
_MAX_TAR_MEMBER_SIZE = 100 * 1024 * 1024 # 100 MB
_MAX_TAR_TOTAL_SIZE = 500 * 1024 * 1024 # 500 MB


def fetch_arxiv_metadata(arxiv_id: str) -> dict:
"""Fetch paper metadata from arXiv API."""
@@ -136,6 +140,19 @@ def download_source(arxiv_id: str, paper_dir: Path, *, extract_figures: bool = F
echo_warning(f"Could not download source for {arxiv_id}: {e}")
return None

# Check download size
content_length = response.headers.get("Content-Length")
try:
content_length_int = int(content_length) if content_length else None
except (ValueError, TypeError):
content_length_int = None
if content_length_int is not None and content_length_int > _MAX_DOWNLOAD_SIZE:
echo_warning(f"Source archive for {arxiv_id} too large ({content_length_int} bytes). Skipping.")
return None
if len(response.content) > _MAX_DOWNLOAD_SIZE:
echo_warning(f"Source archive for {arxiv_id} too large ({len(response.content)} bytes). Skipping.")
return None

# Save and extract tarball
with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as f:
f.write(response.content)
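The hunk above rejects oversized archives twice: first on the declared Content-Length (cheap, but the header can be missing or wrong), then on the actual body length. A self-contained sketch of that two-stage check, using made-up names rather than paperpipe's real download path:

_MAX_DOWNLOAD_SIZE = 500 * 1024 * 1024  # 500 MB, the same budget as above

def body_within_limit(headers: dict, content: bytes, max_size: int = _MAX_DOWNLOAD_SIZE) -> bool:
    """Reject early on a large declared Content-Length, then verify the real payload size."""
    declared = headers.get("Content-Length")
    try:
        declared_int = int(declared) if declared else None
    except (ValueError, TypeError):
        declared_int = None  # absent or malformed header: fall through to the body check
    if declared_int is not None and declared_int > max_size:
        return False
    return len(content) <= max_size

print(body_within_limit({"Content-Length": "1024"}, b"x" * 1024))  # True
print(body_within_limit({"Content-Length": "999999999999"}, b""))  # False (header alone rejects)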
@@ -145,14 +162,21 @@ def download_source(arxiv_id: str, paper_dir: Path, *, extract_figures: bool = F
try:
# Try to open as tar (most common)
with tarfile.open(tar_path) as tar:
tex_members = [m for m in tar.getmembers() if m.isfile() and m.name.endswith(".tex")]
tex_members = [
m for m in tar.getmembers() if m.isfile() and m.name.endswith(".tex") and m.size <= _MAX_TAR_MEMBER_SIZE
]
if tex_members:
tex_by_name: dict[str, str] = {}
total_extracted = 0
for member in tex_members:
if total_extracted + member.size > _MAX_TAR_TOTAL_SIZE:
break
extracted = tar.extractfile(member)
if not extracted:
continue
tex_by_name[member.name] = extracted.read().decode("utf-8", errors="replace")
content = extracted.read()
total_extracted += len(content)
tex_by_name[member.name] = content.decode("utf-8", errors="replace")

preferred_names = ("main.tex", "paper.tex")
preferred = [n for n in tex_by_name if Path(n).name in preferred_names]
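The member filter and the running total above bound how much .tex content is ever inflated from an untrusted archive. The same budget logic can be exercised against a small in-memory tarball; the limits here are shrunk to a few bytes so the demo is instant (the real code uses 100 MB per member and 500 MB total):

import io
import tarfile

MEMBER_LIMIT = 10  # stand-in for _MAX_TAR_MEMBER_SIZE
TOTAL_LIMIT = 15   # stand-in for _MAX_TAR_TOTAL_SIZE

# Build a throwaway archive with one small and one oversized .tex member.
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
    for name, payload in [("main.tex", b"ok"), ("big.tex", b"x" * 50)]:
        info = tarfile.TarInfo(name=name)
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
buf.seek(0)

extracted: dict[str, bytes] = {}
total = 0
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
    members = [m for m in tar.getmembers() if m.isfile() and m.name.endswith(".tex") and m.size <= MEMBER_LIMIT]
    for member in members:
        if total + member.size > TOTAL_LIMIT:
            break  # cumulative budget exhausted, stop extracting
        fh = tar.extractfile(member)
        if not fh:
            continue
        data = fh.read()
        total += len(data)
        extracted[member.name] = data

print(sorted(extracted))  # ['main.tex']: big.tex was dropped by the per-member limit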
@@ -288,6 +312,7 @@ def _extract_figures_from_latex(tex_content: str, tar: tarfile.TarFile, paper_di

figures_dir = paper_dir / "figures"
extracted_count = 0
total_extracted_size = 0

for ref in matches:
# LaTeX allows omitting file extension, so we need to try with and without
@@ -311,10 +336,16 @@ def _extract_figures_from_latex(tex_content: str, tar: tarfile.TarFile, paper_di
for member in tar.getmembers():
if not member.isfile():
continue
if member.size > _MAX_TAR_MEMBER_SIZE:
continue

member_name = member.name.replace("\\", "/")
# Check if member path ends with candidate or matches exactly
if member_name == candidate_path or member_name.endswith("/" + candidate_path):
if total_extracted_size + member.size > _MAX_TAR_TOTAL_SIZE:
echo_warning(" Tar extraction total size limit reached")
return extracted_count

# Extract the file
if not figures_dir.exists():
try:
@@ -331,7 +362,9 @@ def _extract_figures_from_latex(tex_content: str, tar: tarfile.TarFile, paper_di
dest_path = figures_dir / dest_name

# Write file
dest_path.write_bytes(file_obj.read())
data = file_obj.read()
total_extracted_size += len(data)
dest_path.write_bytes(data)
extracted_count += 1
found = True
break
paperpipe/paperqa.py (113 changes: 104 additions & 9 deletions)
@@ -2,6 +2,7 @@

from __future__ import annotations

import io
import os
import pickle
import re
@@ -23,6 +24,84 @@ def _pillow_available() -> bool:
return importlib.util.find_spec("PIL") is not None


_SAFE_PICKLE_BUILTINS = frozenset({"dict", "str", "list", "tuple", "set", "frozenset", "int", "float", "bool", "bytes"})


class _RestrictedUnpickler(pickle.Unpickler):
"""Unpickler that only allows basic Python builtins (no arbitrary code execution)."""

def find_class(self, module: str, name: str) -> Any:
if module == "builtins" and name in _SAFE_PICKLE_BUILTINS:
return (
getattr(__builtins__ if isinstance(__builtins__, dict) else __builtins__, name, None)
or {
"dict": dict,
"str": str,
"list": list,
"tuple": tuple,
"set": set,
"frozenset": frozenset,
"int": int,
"float": float,
"bool": bool,
"bytes": bytes,
}[name]
)
raise pickle.UnpicklingError(f"Restricted: {module}.{name}")


def _safe_unpickle(data: bytes) -> Any:
"""Deserialize pickle data using a restricted unpickler (blocks arbitrary code execution)."""
return _RestrictedUnpickler(io.BytesIO(data)).load()
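A short check of both sides of the restricted unpickler: plain containers built from the whitelisted builtins round-trip normally, while a payload that tries to resolve a callable (the classic os.system gadget) is refused before anything runs. The import path is an assumption based on the file shown here.

import pickle
from paperpipe.paperqa import _safe_unpickle  # assumed import path

# Plain data made only of whitelisted builtins deserializes as usual.
payload = pickle.dumps({"files": ["a.pdf", "b.pdf"], "count": 2})
print(_safe_unpickle(payload))  # {'files': ['a.pdf', 'b.pdf'], 'count': 2}

# A malicious pickle that references os.system is blocked by find_class.
class Evil:
    def __reduce__(self):
        import os
        return (os.system, ("echo pwned",))

try:
    _safe_unpickle(pickle.dumps(Evil()))
except pickle.UnpicklingError as exc:
    print("blocked:", exc)  # e.g. "blocked: Restricted: posix.system" (module name is platform-dependent)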


_MAX_ZLIB_DECOMPRESS_SIZE = 100 * 1024 * 1024 # 100 MB


def _safe_zlib_decompress(data: bytes, *, max_size: int = _MAX_ZLIB_DECOMPRESS_SIZE) -> bytes:
"""Decompress zlib data with a size limit to prevent decompression bombs."""
dobj = zlib.decompressobj()
chunks: list[bytes] = []
total = 0
# Feed compressed data in 64 KB blocks; cap decompressed output via max_length
block_size = 65536
for i in range(0, len(data), block_size):
dobj.decompress(b"", 0) # no-op; ensures unconsumed_tail is drained below
buf = data[i : i + block_size]
while buf:
remaining_budget = max_size - total + 1 # +1 to detect overflow
chunk = dobj.decompress(buf, max_length=remaining_budget)
total += len(chunk)
if total > max_size:
raise ValueError(f"Decompressed data exceeds {max_size} byte limit")
chunks.append(chunk)
buf = dobj.unconsumed_tail
# Flush remaining data
remaining = dobj.flush()
total += len(remaining)
if total > max_size:
raise ValueError(f"Decompressed data exceeds {max_size} byte limit")
chunks.append(remaining)
return b"".join(chunks)


def _validate_index_name(index_name: str) -> str:
"""Validate and return a safe index name (rejects path traversal attempts)."""
name = (index_name or "").strip()
if not name:
raise ValueError("index_name must be non-empty")
if "\x00" in name:
raise ValueError("index_name must not contain null bytes")
if "/" in name or "\\" in name:
raise ValueError("index_name must not contain path separators")
if ".." in name:
raise ValueError("index_name must not contain '..' components")
# Also reject pure dots
if all(c == "." for c in name):
raise ValueError("index_name must not be only dots")
return name
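The validator is a pure name check that never touches the filesystem, so it composes with any of the path builders that now call it. A usage sketch (import path assumed):

from pathlib import Path
from paperpipe.paperqa import _validate_index_name  # assumed import path

print(_validate_index_name("documents"))        # 'documents' passes through unchanged
print(_validate_index_name("  papers-2024  "))  # surrounding whitespace is stripped: 'papers-2024'

for bad in ["", "..", "sub/dir", "sub\\dir", "name\x00"]:
    try:
        _validate_index_name(bad)
    except ValueError as exc:
        print("rejected:", exc)

# With separators and '..' ruled out, joining the name under the index root cannot escape it.
root = Path("/data/.leann/indexes")
print(root / _validate_index_name("documents"))  # /data/.leann/indexes/documents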


def _refresh_pqa_pdf_staging_dir(*, staging_dir: Path, exclude_names: Optional[set[str]] = None) -> int:
"""
Create/update a flat directory containing only PDFs (one per paper) for PaperQA2 indexing.
@@ -138,17 +217,29 @@ def _paperqa_find_crashing_file(*, paper_directory: Path, crashing_doc: str) ->
if not doc:
return None

resolved_parent = paper_directory.resolve()

def _within_parent(p: Path) -> Optional[Path]:
"""Return resolved path if it's within paper_directory, else None."""
try:
rp = p.resolve()
if rp.is_relative_to(resolved_parent) and rp.exists():
return rp
except (OSError, ValueError):
pass
return None

doc_path = Path(doc)
if doc_path.is_absolute():
return doc_path if doc_path.exists() else None
return _within_parent(doc_path)

if ".." in doc_path.parts:
doc_path = Path(doc_path.name)

# Try the path as-is (relative to the paper directory).
candidate = paper_directory / doc_path
if candidate.exists():
return candidate
result = _within_parent(paper_directory / doc_path)
if result is not None:
return result

# Try matching by file name/stem (common when pqa prints just "foo.pdf" or "foo").
name = doc_path.name
@@ -159,29 +250,33 @@ def _paperqa_find_crashing_file(*, paper_directory: Path, crashing_doc: str) ->
try:
for f in paper_directory.iterdir():
if f.name == name or f.stem == expected_stem:
return f
checked = _within_parent(f)
if checked is not None:
return checked
except OSError:
pass

# As a last resort, search recursively by filename.
try:
for f in paper_directory.rglob(name):
if f.name == name:
return f
checked = _within_parent(f)
if checked is not None:
return checked
except OSError:
pass

return None


def _paperqa_index_files_path(*, index_directory: Path, index_name: str) -> Path:
return Path(index_directory) / index_name / "files.zip"
return Path(index_directory) / _validate_index_name(index_name) / "files.zip"


def _paperqa_load_index_files_map(path: Path) -> Optional[dict[str, str]]:
try:
raw = zlib.decompress(path.read_bytes())
obj = pickle.loads(raw)
raw = _safe_zlib_decompress(path.read_bytes())
obj = _safe_unpickle(raw)
except Exception:
return None
if not isinstance(obj, dict):