From d42463236ad03a0a6ad22d32947bea51ad12f011 Mon Sep 17 00:00:00 2001
From: Rodrigo Rodrigues da Silva
Date: Wed, 5 Nov 2025 00:58:41 +0000
Subject: [PATCH] Sanitize txtai metadata persistence

---
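Notes:

A quick sketch of the sanitisation behaviour introduced below, for
illustration only (not part of the applied diff); the expected outputs
simply follow the _SQL_TRANS table and the comment-stripping rules as
written:

    from pave.stores.txtai_store import TxtaiStore

    # ";" and '"' become spaces, anything after "--" is cut off, and
    # remaining single quotes are doubled for the SQL string literal:
    TxtaiStore._sanit_sql("foo'; DROP TABLE users; -- hi")
    # -> "foo''  DROP TABLE users"

    # Field names are reduced to alphanumerics plus "_" and "-":
    TxtaiStore._sanit_field("lang]; DROP")
    # -> "langDROP"

    # max_len truncates before quotes are escaped, so the effective
    # payload (doubled quotes collapsed) is at most max_len characters:
    TxtaiStore._sanit_sql("abc'def'ghi", max_len=5)
    # -> "abc''d"
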
 pave/stores/txtai_store.py           | 113 ++++++++++++++++++++++----
 tests/conftest.py                    |  16 ++++
 tests/test_txtai_store_sql_safety.py | 116 +++++++++++++++++++++++++++
 tests/utils.py                       |  56 +++++++++++--
 4 files changed, 278 insertions(+), 23 deletions(-)
 create mode 100644 tests/test_txtai_store_sql_safety.py

diff --git a/pave/stores/txtai_store.py b/pave/stores/txtai_store.py
index 31280b2..be2c3bb 100644
--- a/pave/stores/txtai_store.py
+++ b/pave/stores/txtai_store.py
@@ -4,7 +4,7 @@
 from __future__ import annotations
 import os, json, operator
 from datetime import datetime
-from typing import Dict, Iterable, List, Any
+from typing import Any, Dict, Iterable, List, Optional
 from threading import Lock
 from contextlib import contextmanager
 from txtai.embeddings import Embeddings
@@ -12,6 +12,14 @@ from pave.config import CFG as c, LOG as log
 
 _LOCKS : dict[str, Lock] = {}
 
+_SQL_TRANS = str.maketrans({
+    ";": " ",
+    '"': " ",
+    "`": " ",
+    "\\": " ",
+    "\x00": "",
+})
+
 def get_lock(key: str) -> Lock:
     if key not in _LOCKS:
         _LOCKS[key] = Lock()
@@ -236,10 +244,10 @@ def index_records(self, tenant: str, collection: str, docid: str,
             md["docid"] = docid
 
         try:
-            meta_json = json.dumps(md, ensure_ascii=False)
-            md = json.loads(meta_json)
-        except:
-            md = {}
+            safe_meta = self._sanit_meta_dict(md)
+            meta_json = json.dumps(safe_meta, ensure_ascii=False)
+        except Exception:
+            safe_meta = {}
             meta_json = ""
 
         rid = str(rid)
@@ -247,9 +255,11 @@
         if not rid.startswith(f"{docid}::"):
             rid = f"{docid}::{rid}"
 
-        meta_side[rid] = md
+        md_for_index = {k: v for k, v in safe_meta.items() if k != "text"}
+
+        meta_side[rid] = safe_meta
         record_ids.append(rid)
-        prepared.append((rid, {"text":txt, **md}, meta_json))
+        prepared.append((rid, {"text": txt, **md_for_index}, meta_json))
 
         self._save_chunk_text(tenant, collection, rid, txt)
         assert txt == (self._load_chunk_text(tenant, collection, rid) or "")
@@ -280,10 +290,15 @@ def _matches_filters(m: Dict[str, Any],
         if not filters:
             return True
 
-        def match(have: Any, cond: str) -> bool:
+        def match(have: Any, cond: Any) -> bool:
             if have is None:
                 return False
-            s = str(cond)
+            if isinstance(have, (list, tuple, set)):
+                return any(match(item, cond) for item in have)
+            if isinstance(cond, str):
+                s = TxtaiStore._sanit_sql(cond)
+            else:
+                s = str(cond)
             hv = str(have)
             # Numeric/date ops
             for op in (">=", "<=", "!=", ">", "<"):
@@ -313,7 +328,7 @@ def match(have: Any, cond: str) -> bool:
             return hv == s
 
         for k, vals in filters.items():
-            if not any(match(m.get(k), v) for v in vals):
+            if not any(match(TxtaiStore._lookup_meta(m, k), v) for v in vals):
                 return False
 
         return True
@@ -325,6 +340,9 @@ def _split_filters(filters: dict[str, Any] | None) -> tuple[dict, dict]:
         pre_f, pos_f = {}, {}
 
         for key, vals in (filters or {}).items():
+            safe_key = TxtaiStore._sanit_field(key)
+            if not safe_key:
+                continue
             if not isinstance(vals, list):
                 vals = [vals]
             exacts, extended = [], []
@@ -338,12 +356,68 @@
             else:
                 exacts.append(v)
             if exacts:
-                pre_f[key] = exacts
+                pre_f[safe_key] = exacts
             if extended:
-                pos_f[key] = extended
+                pos_f[safe_key] = extended
         log.debug(f"after split: PRE {pre_f} POS {pos_f}")
         return pre_f, pos_f
+
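+    # Sanitisation helpers added by this patch: _lookup_meta resolves a
+    # sanitised filter key against the raw metadata keys, _sanit_meta_dict
+    # and _sanit_meta_value normalise metadata recursively, _sanit_sql
+    # strips SQL control characters and comment tokens and escapes quotes,
+    # and _sanit_field reduces field names to alphanumerics plus "_"/"-".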
+    @staticmethod
+    def _lookup_meta(meta: Dict[str, Any] | None, key: str) -> Any:
+        if not meta:
+            return None
+        if key in meta:
+            return meta.get(key)
+        for raw_key, value in meta.items():
+            if TxtaiStore._sanit_field(raw_key) == key:
+                return value
+        return None
+
+    @staticmethod
+    def _sanit_meta_value(value: Any) -> Any:
+        if isinstance(value, dict):
+            return TxtaiStore._sanit_meta_dict(value)
+        if isinstance(value, (list, tuple, set)):
+            return [TxtaiStore._sanit_meta_value(v) for v in value]
+        if isinstance(value, (int, float, bool)) or value is None:
+            return value
+        return TxtaiStore._sanit_sql(value)
+
+    @staticmethod
+    def _sanit_meta_dict(meta: Dict[str, Any] | None) -> Dict[str, Any]:
+        safe: Dict[str, Any] = {}
+        if not isinstance(meta, dict):
+            return safe
+        for raw_key, raw_value in meta.items():
+            safe_key = TxtaiStore._sanit_field(raw_key)
+            if not safe_key or safe_key == "text":
+                continue
+            safe[safe_key] = TxtaiStore._sanit_meta_value(raw_value)
+        return safe
+
+    @staticmethod
+    def _sanit_sql(value: Any, *, max_len: Optional[int] = None) -> str:
+        if value is None:
+            return ""
+        text = str(value).translate(_SQL_TRANS)
+        for token in ("--", "/*", "*/"):
+            if token in text:
+                text = text.split(token, 1)[0]
+        text = text.strip()
+        if max_len is not None and max_len > 0 and len(text) > max_len:
+            text = text[:max_len]
+        return text.replace("'", "''")
+
+    @staticmethod
+    def _sanit_field(name: Any) -> str:
+        if not isinstance(name, str):
+            name = str(name)
+        safe = []
+        for ch in name:
+            if ch.isalnum() or ch in {"_", "-"}:
+                safe.append(ch)
+        return "".join(safe)
+
     @staticmethod
     def _build_sql(query: str, k: int, filters: dict[str, Any], columns: list[str],
                    with_similarity: bool = True, avoid_duplicates = True) -> str:
@@ -356,14 +430,23 @@
         wheres = []
 
         if with_similarity and query:
-            q_safe = query.replace("'", "''")
+            max_len_cfg = c.get("vector_store.txtai.max_query_chars", 512)
+            try:
+                max_len = int(max_len_cfg)
+            except (TypeError, ValueError):
+                max_len = 512
+            limit = max_len if max_len > 0 else None
+            q_safe = TxtaiStore._sanit_sql(query, max_len=limit)
             wheres.append(f"similar('{q_safe}')")
 
         for key, vals in filters.items():
+            safe_key = TxtaiStore._sanit_field(key)
+            if not safe_key:
+                continue
             ors = []
             for v in vals:
-                safe_v = str(v).replace("'", "''")
-                ors.append(f"[{key}] = '{safe_v}'")
+                safe_v = TxtaiStore._sanit_sql(v)
+                ors.append(f"[{safe_key}] = '{safe_v}'")
             or_safe = " OR ".join(ors)
             wheres.append(f"({or_safe})")
 
diff --git a/tests/conftest.py b/tests/conftest.py
index bf4e8b0..8fd621b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,22 @@
 # (C) 2025 Rodrigo Rodrigues da Silva
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import sys
+import types
+
+if "txtai.embeddings" not in sys.modules:
+    txtai_stub = types.ModuleType("txtai")
+    embeddings_stub = types.ModuleType("txtai.embeddings")
+
+    class _StubEmbeddings:  # pragma: no cover - stub for optional dependency
+        def __init__(self, *args, **kwargs):
+            pass
+
+    embeddings_stub.Embeddings = _StubEmbeddings
+    txtai_stub.embeddings = embeddings_stub
+    sys.modules.setdefault("txtai", txtai_stub)
+    sys.modules.setdefault("txtai.embeddings", embeddings_stub)
+
 import pytest
 from fastapi.testclient import TestClient
 from pave.config import get_cfg, reload_cfg
diff --git a/tests/test_txtai_store_sql_safety.py b/tests/test_txtai_store_sql_safety.py
new file mode 100644
index 0000000..e230929
--- /dev/null
+++ b/tests/test_txtai_store_sql_safety.py
@@ -0,0 +1,116 @@
+# (C) 2025 Rodrigo Rodrigues da Silva
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import json
+
+import pytest
+
+from pave.stores import txtai_store as store_mod
+from pave.stores.txtai_store import TxtaiStore
+from pave.config import get_cfg
+from utils import FakeEmbeddings
+
+
+@pytest.fixture(autouse=True)
+def _fake_embeddings(monkeypatch):
+    monkeypatch.setattr(store_mod, "Embeddings", FakeEmbeddings, raising=True)
+
+
+@pytest.fixture()
+def store():
+    return TxtaiStore()
+
+
+def _extract_similarity_term(sql: str) -> str:
+    marker = "similar('"
+    if marker not in sql:
+        raise AssertionError(f"similar() clause missing in SQL: {sql!r}")
+    rest = sql.split(marker, 1)[1]
+    return rest.split("')", 1)[0]
+
+
+def test_build_sql_sanitizes_similarity_term(store):
+    raw_query = "foo'; DROP TABLE users; -- comment"
+    sql = store._build_sql(raw_query, 5, {}, ["id", "text"])
+    term = _extract_similarity_term(sql)
+
+    # injection primitives are stripped or neutralised
+    assert ";" not in term
+    assert "--" not in term
+    # original alpha characters remain so search still works
+    assert "foo" in term
+
+
+def test_build_sql_sanitizes_filter_values(store):
+    filters = {"lang": ["en'; DELETE FROM x;"], "tags": ['alpha"beta']}
+    sql = store._build_sql("foo", 5, filters, ["id", "text"])
+
+    # filter clause should not leak dangerous characters
+    assert ";" not in sql
+    assert '"' not in sql
+    assert "--" not in sql
+
+
+def test_build_sql_normalises_filter_keys(store):
+    filters = {"lang]; DROP": ["en"], 123: ["x"]}
+    sql = store._build_sql("foo", 5, filters, ["id"])
+    assert "[langDROP]" in sql
+    assert "[123]" in sql
+
+
+def test_build_sql_applies_query_length_limit(store):
+    cfg = get_cfg()
+    snapshot = cfg.snapshot()
+    try:
+        cfg.set("vector_store.txtai.max_query_chars", 8)
+        sql = store._build_sql("abcdefghijklmno", 5, {}, ["id"])
+        term = _extract_similarity_term(sql)
+
+        # collapse the doubled quotes to measure the original payload length
+        collapsed = term.replace("''", "'")
+        assert len(collapsed) == 8
+    finally:
+        cfg.replace(data=snapshot)
+
+
+def test_search_handles_special_characters(store):
+    tenant, collection = "tenant", "coll"
+    store.load_or_init(tenant, collection)
+
+    records = [("r1", "hello world", {"lang": "en"})]
+    store.index_records(tenant, collection, "doc", records)
+
+    hits = store.search(tenant, collection, "world; -- comment", k=5)
+    assert hits
+    assert hits[0]["id"].endswith("::r1")
+
+
+def test_round_trip_with_weird_metadata_field(store):
+    tenant, collection = "tenant", "coll"
+    store.load_or_init(tenant, collection)
+
+    weird_key = "meta;`DROP"
+    weird_value = "val'u"
+    records = [("r2", "strange world", {weird_key: weird_value})]
+    store.index_records(tenant, collection, "doc2", records)
+
+    filters = {weird_key: weird_value}
+    hits = store.search(tenant, collection, "strange", k=5, filters=filters)
+
+    assert hits
+    assert hits[0]["id"].endswith("::r2")
+
+    emb = store._emb[(tenant, collection)]
+    safe_key = TxtaiStore._sanit_field(weird_key)
+    assert emb.last_sql and f"[{safe_key}]" in emb.last_sql
+
+    rid = hits[0]["id"]
+    stored_meta = store._load_meta(tenant, collection).get(rid) or {}
+    assert safe_key in stored_meta
+    assert stored_meta[safe_key] == TxtaiStore._sanit_sql(weird_value)
+
+    doc = emb._docs[rid]
+    assert doc["meta"].get(safe_key) == TxtaiStore._sanit_sql(weird_value)
+    serialized = json.loads(doc["meta_json"]) if doc.get("meta_json") else {}
+    assert serialized.get(safe_key) == TxtaiStore._sanit_sql(weird_value)
+    assert hits[0]["meta"].get(safe_key) == TxtaiStore._sanit_sql(weird_value)
diff --git a/tests/utils.py b/tests/utils.py
index 4acbb3c..4b77c35 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -10,18 +10,26 @@ class FakeEmbeddings:
     """Tiny in-memory index. Keeps interface you use in tests."""
 
     def __init__(self, config):  # config unused
-        self._docs = {}  # rid -> (text, meta_json)
+        self._docs = {}  # rid -> {"text": str, "meta_json": str, "meta": dict}
+        self.last_sql = None
 
     def index(self, docs):
-        for rid, text, meta_json in docs:
+        for rid, payload, meta_json in docs:
             assert isinstance(meta_json, str)
-            self._docs[rid] = (text, meta_json)
+            if isinstance(payload, dict):
+                text = payload.get("text")
+                meta = {k: v for k, v in payload.items() if k != "text"}
+            else:
+                text = payload
+                meta = {}
+            self._docs[rid] = {"text": text, "meta_json": meta_json, "meta": meta}
 
     def upsert(self, docs):
        return self.index(docs)
 
     def search(self, sql, k=5):
         import re
+        self.last_sql = sql
         term = None
         m = re.search(r"similar\('([^']+)'", sql)
         if m:
@@ -32,10 +40,42 @@ def search(self, sql, k=5):
         if not term:
             return []
 
-        hits = [
-            {"id": rid, "score": 1.0, "text": txt, "tags": {"docid": "DUMMY"}}
-            for rid, (txt, _) in self._docs.items() if term in str(txt).lower()
-        ]
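+        # Mimic the WHERE clause: pull each "[field] = 'value'" pair out
+        # of the generated SQL and require all of them to match the stored
+        # metadata (values keep their doubled quotes, exactly as
+        # _sanit_sql persisted them).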
+        filter_pairs = re.findall(r"\[([^\]]+)\]\s*=\s*'((?:''|[^'])*)'", sql)
+
+        hits = []
+        for rid, entry in self._docs.items():
+            text = entry.get("text")
+            if text is None:
+                continue
+            if term not in str(text).lower():
+                continue
+
+            metadata = entry.get("meta") or {}
+            include = True
+            for field, raw_val in filter_pairs:
+                stored = metadata.get(field)
+                if stored is None:
+                    include = False
+                    break
+                expected = raw_val
+                if isinstance(stored, (list, tuple, set)):
+                    options = {str(v) for v in stored}
+                    if expected not in options:
+                        include = False
+                        break
+                else:
+                    if str(stored) != expected:
+                        include = False
+                        break
+            if not include:
+                continue
+
+            hits.append({
+                "id": rid,
+                "score": 1.0,
+                "text": text,
+                "docid": metadata.get("docid"),
+            })
         return hits[:10]
         """
         q = (query or "").lower()
@@ -47,7 +87,7 @@
         """
 
     def lookup(self, ids):
-        return {rid: self._docs.get(rid, ("", ""))[0] for rid in ids}
+        return {rid: (self._docs.get(rid) or {}).get("text") for rid in ids}
 
     def delete(self, ids):
         for rid in ids: