From 03c34344b1e14177950c5d41b55b21e75f446da0 Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:14:46 +0100 Subject: [PATCH 01/16] add version flag to CLI. --- llmsql/__main__.py | 16 ++++++++++++++++ llmsql/config/config.py | 12 +++++++++++- llmsql/inference/inference_transformers.py | 19 ++++++++++++++++--- llmsql/inference/inference_vllm.py | 21 ++++++++++++++++++--- llmsql/utils/evaluation_utils.py | 10 +++++++--- llmsql/utils/inference_utils.py | 14 +++++++++----- tests/utils/test_inference_utils.py | 6 +++--- 7 files changed, 80 insertions(+), 18 deletions(-) diff --git a/llmsql/__main__.py b/llmsql/__main__.py index ef52f4e..30b5c80 100644 --- a/llmsql/__main__.py +++ b/llmsql/__main__.py @@ -1,6 +1,7 @@ import argparse import inspect import json +from llmsql.config.config import get_available_versions, DEFAULT_LLMSQL_VERSION def main() -> None: @@ -45,6 +46,13 @@ def main() -> None: --output-file outputs/temp_0.9.jsonl \ --temperature 0.9 \ --generation-kwargs '{"do_sample": true, "top_p": 0.9, "top_k": 40}' + + # 6️⃣ Specify llmsql version (2.0 by default) + llmsql inference --version 1.0 --method transformers \ + --model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \ + --output-file outputs/preds_transformers.jsonl \ + --batch-size 8 \ + --num-fewshots 5 """ inf_parser = subparsers.add_parser( @@ -64,6 +72,14 @@ def main() -> None: help="Inference backend to use ('transformers' or 'vllm').", ) + inf_parser.add_argument( + "--version", + type=str, + default=DEFAULT_LLMSQL_VERSION, + choices=get_available_versions(), + help="Run inference using available version of LLMSQL (2.0 by default)", + ) + # ================================================================ # Parse CLI # ================================================================ diff --git a/llmsql/config/config.py b/llmsql/config/config.py index 24c9f0e..d8f9e36 100644 --- a/llmsql/config/config.py +++ b/llmsql/config/config.py 
@@ -1,2 +1,12 @@ -REPO_ID = "llmsql-bench/llmsql-benchmark" +REPO_IDs: dict = { + "1.0": "llmsql-bench/llmsql-benchmark", + "2.0": "llmsql-bench/llmsql-2.0" +} +DEFAULT_LLMSQL_VERSION = "2.0" DEFAULT_WORKDIR_PATH = "llmsql_workdir" + +def get_repo_id(version: str = DEFAULT_LLMSQL_VERSION) -> str: + return REPO_IDs[version] + +def get_available_versions() -> list[str]: + return list(REPO_IDs.keys()) \ No newline at end of file diff --git a/llmsql/inference/inference_transformers.py b/llmsql/inference/inference_transformers.py index 4d66b90..ae2bc9a 100644 --- a/llmsql/inference/inference_transformers.py +++ b/llmsql/inference/inference_transformers.py @@ -14,6 +14,7 @@ results = inference_transformers( model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct", + repo_id="llmsql-bench/llmsql-2.0", output_file="outputs/preds_transformers.jsonl", questions_path="data/questions.jsonl", tables_path="data/tables.jsonl", @@ -46,7 +47,7 @@ from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer -from llmsql.config.config import DEFAULT_WORKDIR_PATH +from llmsql.config.config import DEFAULT_WORKDIR_PATH, DEFAULT_LLMSQL_VERSION, get_repo_id from llmsql.loggers.logging_config import log from llmsql.utils.inference_utils import _maybe_download, _setup_seed from llmsql.utils.utils import ( @@ -85,6 +86,7 @@ def inference_transformers( top_k: int = 50, generation_kwargs: dict[str, Any] | None = None, # --- Benchmark Parameters --- + version: str = DEFAULT_LLMSQL_VERSION, output_file: str = "llm_sql_predictions.jsonl", questions_path: str | None = None, tables_path: str | None = None, @@ -128,6 +130,7 @@ def inference_transformers( 'top_p', 'top_k' are handled separately. # Benchmark: + version: LLMSQL version output_file: Output JSONL file path for completions. questions_path: Path to benchmark questions JSONL. tables_path: Path to benchmark tables JSONL. 
@@ -208,8 +211,18 @@ def inference_transformers( model.eval() # --- Load necessary files --- - questions_path = _maybe_download("questions.jsonl", questions_path) - tables_path = _maybe_download("tables.jsonl", tables_path) + repo_id = get_repo_id(version) + + questions_path = _maybe_download( + "questions.jsonl", + questions_path, + repo_id, + ) + tables_path = _maybe_download( + "tables.jsonl", + tables_path, + repo_id, + ) questions = load_jsonl(questions_path) tables_list = load_jsonl(tables_path) diff --git a/llmsql/inference/inference_vllm.py b/llmsql/inference/inference_vllm.py index bbeb2a9..d42e413 100644 --- a/llmsql/inference/inference_vllm.py +++ b/llmsql/inference/inference_vllm.py @@ -14,6 +14,7 @@ results = inference_vllm( model_name="Qwen/Qwen2.5-1.5B-Instruct", + version="2.0", output_file="outputs/predictions.jsonl", questions_path="data/questions.jsonl", tables_path="data/tables.jsonl", @@ -47,7 +48,7 @@ from tqdm import tqdm from vllm import LLM, SamplingParams -from llmsql.config.config import DEFAULT_WORKDIR_PATH +from llmsql.config.config import DEFAULT_WORKDIR_PATH, get_repo_id, DEFAULT_LLMSQL_VERSION from llmsql.loggers.logging_config import log from llmsql.utils.inference_utils import _maybe_download, _setup_seed from llmsql.utils.utils import ( @@ -75,6 +76,7 @@ def inference_vllm( do_sample: bool = True, sampling_kwargs: dict[str, Any] | None = None, # === Benchmark Parameters === + version: str = DEFAULT_LLMSQL_VERSION, output_file: str = "llm_sql_predictions.jsonl", questions_path: str | None = None, tables_path: str | None = None, @@ -107,6 +109,7 @@ def inference_vllm( separately and will override values here. # Benchmark: + version: LLMSQL version output_file: Path to write outputs (will be overwritten). questions_path: Path to questions.jsonl (auto-downloads if missing). tables_path: Path to tables.jsonl (auto-downloads if missing). 
@@ -129,8 +132,20 @@ def inference_vllm( # --- load input data --- log.info("Preparing questions and tables...") - questions_path = _maybe_download("questions.jsonl", questions_path) - tables_path = _maybe_download("tables.jsonl", tables_path) + + repo_id = get_repo_id(version) + + questions_path = _maybe_download( + "questions.jsonl", + questions_path, + repo_id, + ) + tables_path = _maybe_download( + "tables.jsonl", + tables_path, + repo_id, + ) + questions = load_jsonl(questions_path) tables_list = load_jsonl(tables_path) tables = {t["table_id"]: t for t in tables_list} diff --git a/llmsql/utils/evaluation_utils.py b/llmsql/utils/evaluation_utils.py index c00fb7b..be394fd 100644 --- a/llmsql/utils/evaluation_utils.py +++ b/llmsql/utils/evaluation_utils.py @@ -5,7 +5,7 @@ from huggingface_hub import hf_hub_download -from llmsql.config.config import REPO_ID +from llmsql.config.config import DEFAULT_REPO_ID from llmsql.loggers.logging_config import log from llmsql.utils.regex_extractor import find_sql @@ -170,10 +170,14 @@ def evaluate_sample( ) -def download_benchmark_file(filename: str, local_dir: Path) -> str: +def download_benchmark_file( + filename: str, + local_dir: Path, + repo_id: str = DEFAULT_REPO_ID, +) -> str: """Download a benchmark file from HuggingFace Hub.""" file_path = hf_hub_download( - repo_id=REPO_ID, + repo_id=repo_id, filename=filename, repo_type="dataset", local_dir=local_dir, diff --git a/llmsql/utils/inference_utils.py b/llmsql/utils/inference_utils.py index ee7d300..7130aa7 100644 --- a/llmsql/utils/inference_utils.py +++ b/llmsql/utils/inference_utils.py @@ -5,14 +5,14 @@ import numpy as np import torch -from llmsql.config.config import DEFAULT_WORKDIR_PATH, REPO_ID +from llmsql.config.config import DEFAULT_WORKDIR_PATH from llmsql.loggers.logging_config import log # --- Load benchmark data --- -def _download_file(filename: str) -> str: +def _download_file(filename: str, repo_id: str) -> str: path = hf_hub_download( - repo_id=REPO_ID, + 
repo_id=repo_id, filename=filename, repo_type="dataset", local_dir=DEFAULT_WORKDIR_PATH, @@ -29,14 +29,18 @@ def _setup_seed(seed: int) -> None: torch.cuda.manual_seed_all(seed) -def _maybe_download(filename: str, local_path: str | None) -> str: +def _maybe_download( + filename: str, + local_path: str | None, + repo_id: str +) -> str: if local_path is not None: return local_path target_path = Path(DEFAULT_WORKDIR_PATH) / filename if not target_path.exists(): log.info(f"Downloading {filename} from Hugging Face Hub...") local_path = hf_hub_download( - repo_id=REPO_ID, + repo_id=repo_id, filename=filename, repo_type="dataset", local_dir=DEFAULT_WORKDIR_PATH, diff --git a/tests/utils/test_inference_utils.py b/tests/utils/test_inference_utils.py index 634c26d..afbe5ef 100644 --- a/tests/utils/test_inference_utils.py +++ b/tests/utils/test_inference_utils.py @@ -6,7 +6,7 @@ import pytest import torch -from llmsql.config.config import DEFAULT_WORKDIR_PATH, REPO_ID +from llmsql.config.config import DEFAULT_WORKDIR_PATH, get_repo_id from llmsql.utils import inference_utils as mod @@ -16,7 +16,7 @@ async def test_download_file(monkeypatch, tmp_path): expected_path = str(tmp_path / "questions.jsonl") def fake_hf_hub_download(repo_id, filename, repo_type, local_dir): - assert repo_id == REPO_ID + assert repo_id == get_repo_id() assert repo_type == "dataset" assert local_dir == DEFAULT_WORKDIR_PATH assert filename == "questions.jsonl" @@ -80,5 +80,5 @@ def fake_hf_hub_download(**kwargs): path = mod._maybe_download(filename, local_path=None) assert Path(path).exists() - assert called["repo_id"] == REPO_ID + assert called["repo_id"] == get_repo_id() assert called["filename"] == filename From 4936942a1e5c399e145bcbd5bd2fab44ba4c0b19 Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:45:43 +0100 Subject: [PATCH 02/16] version flag added to evaluation --- llmsql/evaluation/evaluate.py | 10 +++++++--- 
llmsql/utils/evaluation_utils.py | 5 ++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/llmsql/evaluation/evaluate.py b/llmsql/evaluation/evaluate.py index c07eb93..2500b0b 100644 --- a/llmsql/evaluation/evaluate.py +++ b/llmsql/evaluation/evaluate.py @@ -14,7 +14,7 @@ from rich.progress import track -from llmsql.config.config import DEFAULT_WORKDIR_PATH +from llmsql.config.config import DEFAULT_WORKDIR_PATH, DEFAULT_LLMSQL_VERSION, get_repo_id from llmsql.utils.evaluation_utils import ( connect_sqlite, download_benchmark_file, @@ -27,6 +27,7 @@ def evaluate( outputs: str | list[dict[int, str | int]], *, + version: str = DEFAULT_LLMSQL_VERSION, workdir_path: str | None = DEFAULT_WORKDIR_PATH, questions_path: str | None = None, db_path: str | None = None, @@ -38,6 +39,7 @@ def evaluate( Evaluate predicted SQL queries against the LLMSQL benchmark. Args: + version: LLMSQL version outputs: Either a JSONL file path or a list of dicts. workdir_path: Directory for auto-downloads (ignored if all paths provided). questions_path: Manual path to benchmark questions JSONL. 
@@ -53,6 +55,8 @@ def evaluate( # Determine input type input_mode = "jsonl_path" if isinstance(outputs, str) else "dict_list" + repo_id = get_repo_id(version) + # --- Resolve inputs if needed --- workdir = Path(workdir_path) if workdir_path else None if workdir_path is not None and (questions_path is None or db_path is None): @@ -68,7 +72,7 @@ def evaluate( questions_path = ( str(local_q) if local_q.is_file() - else download_benchmark_file("questions.jsonl", workdir) + else download_benchmark_file(repo_id, "questions.jsonl", workdir) ) if db_path is None: @@ -81,7 +85,7 @@ def evaluate( db_path = ( str(local_db) if local_db.is_file() - else download_benchmark_file("sqlite_tables.db", workdir) + else download_benchmark_file(repo_id, "sqlite_tables.db", workdir) ) # --- Load benchmark questions --- diff --git a/llmsql/utils/evaluation_utils.py b/llmsql/utils/evaluation_utils.py index be394fd..de86815 100644 --- a/llmsql/utils/evaluation_utils.py +++ b/llmsql/utils/evaluation_utils.py @@ -5,7 +5,6 @@ from huggingface_hub import hf_hub_download -from llmsql.config.config import DEFAULT_REPO_ID from llmsql.loggers.logging_config import log from llmsql.utils.regex_extractor import find_sql @@ -171,9 +170,9 @@ def evaluate_sample( def download_benchmark_file( + repo_id: str, filename: str, - local_dir: Path, - repo_id: str = DEFAULT_REPO_ID, + local_dir: Path ) -> str: """Download a benchmark file from HuggingFace Hub.""" file_path = hf_hub_download( From 120de0204dd3d863573ea02cd82cbdd478aa7026 Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Sun, 15 Feb 2026 23:54:06 +0100 Subject: [PATCH 03/16] adjusted function signature in conftest.py --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 58b585e..9125681 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -104,7 +104,7 @@ def mock_utils(mocker, tmp_path): # download files mocker.patch( 
"llmsql.evaluation.evaluate.download_benchmark_file", - side_effect=lambda filename, wd: str(Path(wd) / filename), + side_effect=lambda repo_id, filename, local_dir: str(Path(local_dir) / filename), ) # report writer From 041580ec442d1dd967b425789f5ee0284f97fe87 Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Mon, 16 Feb 2026 00:08:04 +0100 Subject: [PATCH 04/16] fixed signatures in tests --- llmsql/inference/inference_transformers.py | 8 ++++---- llmsql/inference/inference_vllm.py | 8 ++++---- llmsql/utils/inference_utils.py | 4 ++-- tests/inference/test_inference_stability.py | 2 +- tests/utils/test_inference_utils.py | 6 +++--- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/llmsql/inference/inference_transformers.py b/llmsql/inference/inference_transformers.py index ae2bc9a..1dd87c8 100644 --- a/llmsql/inference/inference_transformers.py +++ b/llmsql/inference/inference_transformers.py @@ -214,14 +214,14 @@ def inference_transformers( repo_id = get_repo_id(version) questions_path = _maybe_download( - "questions.jsonl", - questions_path, repo_id, + "questions.jsonl", + questions_path ) tables_path = _maybe_download( - "tables.jsonl", - tables_path, repo_id, + "tables.jsonl", + tables_path ) questions = load_jsonl(questions_path) diff --git a/llmsql/inference/inference_vllm.py b/llmsql/inference/inference_vllm.py index d42e413..daf677a 100644 --- a/llmsql/inference/inference_vllm.py +++ b/llmsql/inference/inference_vllm.py @@ -136,14 +136,14 @@ def inference_vllm( repo_id = get_repo_id(version) questions_path = _maybe_download( - "questions.jsonl", - questions_path, repo_id, + "questions.jsonl", + questions_path ) tables_path = _maybe_download( - "tables.jsonl", - tables_path, repo_id, + "tables.jsonl", + tables_path ) questions = load_jsonl(questions_path) diff --git a/llmsql/utils/inference_utils.py b/llmsql/utils/inference_utils.py index 7130aa7..da459f0 100644 --- a/llmsql/utils/inference_utils.py 
+++ b/llmsql/utils/inference_utils.py @@ -30,9 +30,9 @@ def _setup_seed(seed: int) -> None: def _maybe_download( + repo_id: str, filename: str, - local_path: str | None, - repo_id: str + local_path: str | None ) -> str: if local_path is not None: return local_path diff --git a/tests/inference/test_inference_stability.py b/tests/inference/test_inference_stability.py index cfa2a06..be721ea 100644 --- a/tests/inference/test_inference_stability.py +++ b/tests/inference/test_inference_stability.py @@ -81,7 +81,7 @@ async def test_inference_vllm_download_if_missing(monkeypatch, tmp_path): called = {"q": 0, "t": 0} - def fake_download(filename, path, **_): + def fake_download(repo_id, filename, path, **_): called["q" if "questions" in filename else "t"] += 1 path = tmp_path / filename # Write minimal JSONL for subsequent load_jsonl diff --git a/tests/utils/test_inference_utils.py b/tests/utils/test_inference_utils.py index afbe5ef..3d8faf2 100644 --- a/tests/utils/test_inference_utils.py +++ b/tests/utils/test_inference_utils.py @@ -6,7 +6,7 @@ import pytest import torch -from llmsql.config.config import DEFAULT_WORKDIR_PATH, get_repo_id +from llmsql.config.config import DEFAULT_WORKDIR_PATH, get_repo_id, DEFAULT_LLMSQL_VERSION from llmsql.utils import inference_utils as mod @@ -16,7 +16,7 @@ async def test_download_file(monkeypatch, tmp_path): expected_path = str(tmp_path / "questions.jsonl") def fake_hf_hub_download(repo_id, filename, repo_type, local_dir): - assert repo_id == get_repo_id() + assert repo_id == get_repo_id(DEFAULT_LLMSQL_VERSION) assert repo_type == "dataset" assert local_dir == DEFAULT_WORKDIR_PATH assert filename == "questions.jsonl" @@ -80,5 +80,5 @@ def fake_hf_hub_download(**kwargs): path = mod._maybe_download(filename, local_path=None) assert Path(path).exists() - assert called["repo_id"] == get_repo_id() + assert called["repo_id"] == get_repo_id(DEFAULT_LLMSQL_VERSION) assert called["filename"] == filename From 
24b924c7b321a75689026ab3b84e5eb535a66c77 Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Mon, 16 Feb 2026 19:45:40 +0100 Subject: [PATCH 05/16] test change --- llmsql/utils/inference_utils.py | 2 +- tests/utils/test_inference_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llmsql/utils/inference_utils.py b/llmsql/utils/inference_utils.py index da459f0..4189aad 100644 --- a/llmsql/utils/inference_utils.py +++ b/llmsql/utils/inference_utils.py @@ -10,7 +10,7 @@ # --- Load benchmark data --- -def _download_file(filename: str, repo_id: str) -> str: +def _download_file(repo_id: str, filename: str) -> str: path = hf_hub_download( repo_id=repo_id, filename=filename, diff --git a/tests/utils/test_inference_utils.py b/tests/utils/test_inference_utils.py index 3d8faf2..114b86e 100644 --- a/tests/utils/test_inference_utils.py +++ b/tests/utils/test_inference_utils.py @@ -23,7 +23,7 @@ def fake_hf_hub_download(repo_id, filename, repo_type, local_dir): return expected_path monkeypatch.setattr(mod, "hf_hub_download", fake_hf_hub_download) - path = mod._download_file("questions.jsonl") + path = mod._download_file(get_repo_id(DEFAULT_LLMSQL_VERSION), "questions.jsonl") assert path == expected_path From 9a966ee26800ecfa2f5e51ce6b0590af73c0b4e9 Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Mon, 16 Feb 2026 19:54:13 +0100 Subject: [PATCH 06/16] fixing _maybe_download arguments mismatch in tests. 
--- tests/utils/test_inference_utils.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/utils/test_inference_utils.py b/tests/utils/test_inference_utils.py index 114b86e..4fd7a0e 100644 --- a/tests/utils/test_inference_utils.py +++ b/tests/utils/test_inference_utils.py @@ -52,13 +52,21 @@ async def test_maybe_download_existing_file(tmp_path, monkeypatch): existing.write_text("dummy") monkeypatch.setattr(mod, "hf_hub_download", lambda *a, **kw: "FAIL") # Should return local path directly - path = mod._maybe_download("questions.jsonl", local_path=str(existing)) + path = mod._maybe_download( + get_repo_id(DEFAULT_LLMSQL_VERSION), + "questions.jsonl", + local_path=str(existing) + ) assert path == str(existing) # Should also return target_path if file exists in DEFAULT_WORKDIR_PATH monkeypatch.setattr(mod, "hf_hub_download", lambda *a, **kw: "FAIL") monkeypatch.setattr(mod, "DEFAULT_WORKDIR_PATH", str(tmp_path)) - path2 = mod._maybe_download("questions.jsonl", local_path=None) + path2 = mod._maybe_download( + get_repo_id(DEFAULT_LLMSQL_VERSION), + "questions.jsonl", + local_path=None + ) assert Path(path2).exists() or path2.endswith("questions.jsonl") @@ -78,7 +86,11 @@ def fake_hf_hub_download(**kwargs): monkeypatch.setattr(mod, "hf_hub_download", fake_hf_hub_download) - path = mod._maybe_download(filename, local_path=None) + path = mod._maybe_download( + get_repo_id(DEFAULT_LLMSQL_VERSION), + filename, + local_path=None + ) assert Path(path).exists() assert called["repo_id"] == get_repo_id(DEFAULT_LLMSQL_VERSION) assert called["filename"] == filename From 45539757b838c9967c215320d7b12af60d02bf82 Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Thu, 19 Feb 2026 22:50:39 +0100 Subject: [PATCH 07/16] Add llmsql versions test on inference_transformers --- ...ence_transformers_on_different_versions.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 
tests/inference/test_inference_transformers_on_different_versions.py diff --git a/tests/inference/test_inference_transformers_on_different_versions.py b/tests/inference/test_inference_transformers_on_different_versions.py new file mode 100644 index 0000000..af457a9 --- /dev/null +++ b/tests/inference/test_inference_transformers_on_different_versions.py @@ -0,0 +1,94 @@ +import json +from pathlib import Path +import tempfile + +import pytest + +from llmsql.inference.inference_transformers import inference_transformers + +# --- Minimal fake benchmark data for testing --- +questions = [ + {"question_id": "q1", "table_id": "t1", "question": "Select name from students;"}, + {"question_id": "q2", "table_id": "t1", "question": "Count students older than 20;"}, +] +tables = [ + { + "table_id": "t1", + "header": ["id", "name", "age"], + "types": ["int", "str", "int"], + "rows": [[1, "Alice", 21], [2, "Bob", 19]], + } +] + + +# Save minimal JSONL files for testing +def _write_jsonl(data, path: Path): + with path.open("w", encoding="utf-8") as f: + for row in data: + f.write(json.dumps(row) + "\n") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("version_arg", [None, "2.0", "1.0"]) +async def test_inference_stability_on_valid_version_flags(version_arg): + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir_path = Path(tmpdir) + questions_file = tmpdir_path / "questions.jsonl" + tables_file = tmpdir_path / "tables.jsonl" + output_file = tmpdir_path / "outputs.jsonl" + + _write_jsonl(questions, questions_file) + _write_jsonl(tables, tables_file) + + kwargs = { + "model_or_model_name_or_path": "sshleifer/tiny-gpt2", + "tokenizer_or_name": "sshleifer/tiny-gpt2", + "output_file": str(output_file), + "questions_path": str(questions_file), + "tables_path": str(tables_file), + "batch_size": 1, + "max_new_tokens": 8, + "temperature": 0.0, + "do_sample": False, + } + + if version_arg is not None: + kwargs["version"] = version_arg + + results = inference_transformers(**kwargs) + 
+ assert isinstance(results, list) + assert all("question_id" in r and "completion" in r for r in results) + assert output_file.exists() + + if version_arg is not None: + for r in results: + assert "completion" in r + + +@pytest.mark.asyncio +async def test_inference_stability_on_invalid_version_flag(): + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir_path = Path(tmpdir) + q_file = tmpdir_path / "questions.jsonl" + t_file = tmpdir_path / "tables.jsonl" + out_file = tmpdir_path / "outputs.jsonl" + + _write_jsonl(questions, q_file) + _write_jsonl(tables, t_file) + + kwargs = { + "model_or_model_name_or_path": "sshleifer/tiny-gpt2", + "tokenizer_or_name": "sshleifer/tiny-gpt2", + "output_file": str(out_file), + "questions_path": str(q_file), + "tables_path": str(t_file), + "batch_size": 1, + "max_new_tokens": 8, + "temperature": 0.0, + "do_sample": False, + "version": "1.1", # invalid version + } + + with pytest.raises(Exception): + inference_transformers(**kwargs) From 6787db4b1098b0819a75fa685721c0b9b4a1c192 Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Thu, 19 Feb 2026 22:54:23 +0100 Subject: [PATCH 08/16] fixed hardcoded LLMSQL in the test. 
--- .../test_inference_transformers_on_different_versions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/inference/test_inference_transformers_on_different_versions.py b/tests/inference/test_inference_transformers_on_different_versions.py index af457a9..80541ae 100644 --- a/tests/inference/test_inference_transformers_on_different_versions.py +++ b/tests/inference/test_inference_transformers_on_different_versions.py @@ -4,6 +4,7 @@ import pytest +from llmsql.config.config import get_available_versions from llmsql.inference.inference_transformers import inference_transformers # --- Minimal fake benchmark data for testing --- @@ -20,6 +21,9 @@ } ] +VALID_LLMSQL_VERSIONS = [None] + get_available_versions() +INVALID_LLMSQL_VERSION = "1.1" + # Save minimal JSONL files for testing def _write_jsonl(data, path: Path): @@ -29,7 +33,7 @@ def _write_jsonl(data, path: Path): @pytest.mark.asyncio -@pytest.mark.parametrize("version_arg", [None, "2.0", "1.0"]) +@pytest.mark.parametrize("version_arg", VALID_LLMSQL_VERSIONS) async def test_inference_stability_on_valid_version_flags(version_arg): with tempfile.TemporaryDirectory() as tmpdir: tmpdir_path = Path(tmpdir) @@ -87,7 +91,7 @@ async def test_inference_stability_on_invalid_version_flag(): "max_new_tokens": 8, "temperature": 0.0, "do_sample": False, - "version": "1.1", # invalid version + "version": INVALID_LLMSQL_VERSION, # invalid version } with pytest.raises(Exception): From e5b8bc380f59eee1a0ec226bd62055781d9bfeb6 Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Thu, 19 Feb 2026 23:06:13 +0100 Subject: [PATCH 09/16] added evaluation and vllm tests for different --version flag values (valid ones and an invalid one) --- ...est_evaluator_different_llmsql_versions.py | 55 ++++++++++ ...transformers_different_llmsql_versions.py} | 1 - ...nference_vllm_different_llmsql_versions.py | 100 ++++++++++++++++++ 3 files changed, 155 insertions(+), 1 
deletion(-) create mode 100644 tests/evaluation/test_evaluator_different_llmsql_versions.py rename tests/inference/{test_inference_transformers_on_different_versions.py => test_inference_transformers_different_llmsql_versions.py} (98%) create mode 100644 tests/inference/test_inference_vllm_different_llmsql_versions.py diff --git a/tests/evaluation/test_evaluator_different_llmsql_versions.py b/tests/evaluation/test_evaluator_different_llmsql_versions.py new file mode 100644 index 0000000..4fd0b19 --- /dev/null +++ b/tests/evaluation/test_evaluator_different_llmsql_versions.py @@ -0,0 +1,55 @@ +import pytest +from llmsql import evaluate +from llmsql.config.config import get_available_versions + +VALID_LLMSQL_VERSIONS = [None] + get_available_versions() +INVALID_LLMSQL_VERSION = "1.1" + + +@pytest.mark.parametrize("version_arg", VALID_LLMSQL_VERSIONS) +def test_evaluate_runs_with_valid_versions(monkeypatch, tmp_path, version_arg): + outputs_path = tmp_path / "outputs.jsonl" + outputs_path.write_text('{"question_id":1,"completion":"SELECT 1"}') + + questions_path = tmp_path / "questions.jsonl" + questions_path.write_text('{"question_id":1,"table_id":1,"question":"x","sql":"SELECT 1"}') + + monkeypatch.setattr("llmsql.utils.evaluation_utils.evaluate_sample", lambda *a, **k: (1, None, {})) + monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) + monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) + + kwargs = { + "outputs": str(outputs_path), + "questions_path": str(questions_path), + "db_path": tmp_path / "dummy.db", + "show_mismatches": False, + } + + if version_arg is not None: + kwargs["version"] = version_arg + + # Should NOT raise + evaluate(**kwargs) + + +def test_evaluate_raises_with_invalid_version(monkeypatch, tmp_path): + outputs_path = tmp_path / "outputs.jsonl" + outputs_path.write_text('{"question_id":1,"completion":"SELECT 1"}') + + questions_path = tmp_path / "questions.jsonl" + 
questions_path.write_text('{"question_id":1,"table_id":1,"question":"x","sql":"SELECT 1"}') + + monkeypatch.setattr("llmsql.utils.evaluation_utils.evaluate_sample", lambda *a, **k: (1, None, {})) + monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) + monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) + + kwargs = { + "outputs": str(outputs_path), + "questions_path": str(questions_path), + "db_path": tmp_path / "dummy.db", + "show_mismatches": False, + "version": INVALID_LLMSQL_VERSION, + } + + with pytest.raises(Exception): + evaluate(**kwargs) \ No newline at end of file diff --git a/tests/inference/test_inference_transformers_on_different_versions.py b/tests/inference/test_inference_transformers_different_llmsql_versions.py similarity index 98% rename from tests/inference/test_inference_transformers_on_different_versions.py rename to tests/inference/test_inference_transformers_different_llmsql_versions.py index 80541ae..cb1dec5 100644 --- a/tests/inference/test_inference_transformers_on_different_versions.py +++ b/tests/inference/test_inference_transformers_different_llmsql_versions.py @@ -7,7 +7,6 @@ from llmsql.config.config import get_available_versions from llmsql.inference.inference_transformers import inference_transformers -# --- Minimal fake benchmark data for testing --- questions = [ {"question_id": "q1", "table_id": "t1", "question": "Select name from students;"}, {"question_id": "q2", "table_id": "t1", "question": "Count students older than 20;"}, diff --git a/tests/inference/test_inference_vllm_different_llmsql_versions.py b/tests/inference/test_inference_vllm_different_llmsql_versions.py new file mode 100644 index 0000000..78615c2 --- /dev/null +++ b/tests/inference/test_inference_vllm_different_llmsql_versions.py @@ -0,0 +1,100 @@ +import json +from pathlib import Path +from unittest.mock import MagicMock +import pytest + +import llmsql.inference.inference_vllm as mod +from 
llmsql.config.config import get_available_versions + +questions = [ + {"question_id": "q1", "table_id": "t1", "question": "Select name from students;"}, + {"question_id": "q2", "table_id": "t1", "question": "Count students older than 20;"}, +] +tables = [ + { + "table_id": "t1", + "header": ["id", "name", "age"], + "types": ["int", "str", "int"], + "rows": [[1, "Alice", 21], [2, "Bob", 19]], + } +] + +VALID_LLMSQL_VERSIONS = [None] + get_available_versions() +INVALID_LLMSQL_VERSION = "1.1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("version_arg", VALID_LLMSQL_VERSIONS) +async def test_inference_vllm_valid_versions(monkeypatch, tmp_path, version_arg): + """Test inference_vllm with valid version flags using local JSONL files.""" + q_file = tmp_path / "questions.jsonl" + t_file = tmp_path / "tables.jsonl" + out_file = tmp_path / "out.jsonl" + + q_file.write_text("\n".join(json.dumps(q) for q in questions)) + t_file.write_text("\n".join(json.dumps(t) for t in tables)) + + monkeypatch.setattr(mod, "load_jsonl", + lambda path: [json.loads(line) for line in Path(path).read_text().splitlines()]) + monkeypatch.setattr(mod, "overwrite_jsonl", lambda path: None) + monkeypatch.setattr(mod, "save_jsonl_lines", lambda path, lines: None) + monkeypatch.setattr(mod, "choose_prompt_builder", lambda shots: lambda *a: "PROMPT") + + fake_llm = MagicMock() + fake_llm.generate.return_value = [MagicMock(outputs=[MagicMock(text="SELECT 1")])] + monkeypatch.setattr(mod, "LLM", lambda *a, **kw: fake_llm) + + kwargs = { + "model_name": "dummy-model", + "output_file": str(out_file), + "questions_path": str(q_file), + "tables_path": str(t_file), + "num_fewshots": 1, + "batch_size": 1, + "max_new_tokens": 8, + "temperature": 0.0, + } + if version_arg is not None: + kwargs["version"] = version_arg + + results = mod.inference_vllm(**kwargs) + + assert isinstance(results, list) + assert all("question_id" in r and "completion" in r for r in results) + assert out_file.exists() + + 
+@pytest.mark.asyncio +async def test_inference_vllm_invalid_version(monkeypatch, tmp_path): + """Test inference_vllm raises exception with invalid version flag.""" + q_file = tmp_path / "questions.jsonl" + t_file = tmp_path / "tables.jsonl" + out_file = tmp_path / "out.jsonl" + + q_file.write_text("\n".join(json.dumps(q) for q in questions)) + t_file.write_text("\n".join(json.dumps(t) for t in tables)) + + monkeypatch.setattr(mod, "load_jsonl", + lambda path: [json.loads(line) for line in Path(path).read_text().splitlines()]) + monkeypatch.setattr(mod, "overwrite_jsonl", lambda path: None) + monkeypatch.setattr(mod, "save_jsonl_lines", lambda path, lines: None) + monkeypatch.setattr(mod, "choose_prompt_builder", lambda shots: lambda *a: "PROMPT") + + fake_llm = MagicMock() + fake_llm.generate.return_value = [MagicMock(outputs=[MagicMock(text="SELECT 1")])] + monkeypatch.setattr(mod, "LLM", lambda *a, **kw: fake_llm) + + kwargs = { + "model_name": "dummy-model", + "output_file": str(out_file), + "questions_path": str(q_file), + "tables_path": str(t_file), + "num_fewshots": 1, + "batch_size": 1, + "max_new_tokens": 8, + "temperature": 0.0, + "version": INVALID_LLMSQL_VERSION, # invalid version + } + + with pytest.raises(Exception): + mod.inference_vllm(**kwargs) \ No newline at end of file From 4a1fd4691813a8a901da7783a48cd944f2684ef2 Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Fri, 20 Feb 2026 16:50:24 +0100 Subject: [PATCH 10/16] fixed evaluation test. 
--- ...est_evaluator_different_llmsql_versions.py | 72 +++++++++++++------ 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/tests/evaluation/test_evaluator_different_llmsql_versions.py b/tests/evaluation/test_evaluator_different_llmsql_versions.py index 4fd0b19..66a4119 100644 --- a/tests/evaluation/test_evaluator_different_llmsql_versions.py +++ b/tests/evaluation/test_evaluator_different_llmsql_versions.py @@ -1,27 +1,47 @@ +import sqlite3 import pytest + from llmsql import evaluate from llmsql.config.config import get_available_versions + VALID_LLMSQL_VERSIONS = [None] + get_available_versions() -INVALID_LLMSQL_VERSION = "1.1" +INVALID_LLMSQL_VERSION = "999.0" @pytest.mark.parametrize("version_arg", VALID_LLMSQL_VERSIONS) def test_evaluate_runs_with_valid_versions(monkeypatch, tmp_path, version_arg): + # --- Minimal fake outputs and questions --- outputs_path = tmp_path / "outputs.jsonl" - outputs_path.write_text('{"question_id":1,"completion":"SELECT 1"}') + outputs_path.write_text('{"question_id":1,"completion":"SELECT 1"}\n') questions_path = tmp_path / "questions.jsonl" - questions_path.write_text('{"question_id":1,"table_id":1,"question":"x","sql":"SELECT 1"}') + questions_path.write_text( + '{"question_id":1,"table_id":1,"question":"x","sql":"SELECT 1"}\n' + ) - monkeypatch.setattr("llmsql.utils.evaluation_utils.evaluate_sample", lambda *a, **k: (1, None, {})) - monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) - monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) + # --- Create a real (empty) SQLite database file --- + db_path = tmp_path / "dummy.db" + sqlite3.connect(db_path).close() + + # --- Patch heavy evaluation internals --- + monkeypatch.setattr( + "llmsql.evaluation.evaluate.evaluate_sample", + lambda *a, **k: (1, None, {}) + ) + monkeypatch.setattr( + "llmsql.evaluation.evaluate.log_mismatch", + lambda **k: None + ) + monkeypatch.setattr( + 
"llmsql.evaluation.evaluate.print_summary", + lambda *a, **k: None + ) kwargs = { "outputs": str(outputs_path), "questions_path": str(questions_path), - "db_path": tmp_path / "dummy.db", + "db_path": str(db_path), "show_mismatches": False, } @@ -34,22 +54,34 @@ def test_evaluate_runs_with_valid_versions(monkeypatch, tmp_path, version_arg): def test_evaluate_raises_with_invalid_version(monkeypatch, tmp_path): outputs_path = tmp_path / "outputs.jsonl" - outputs_path.write_text('{"question_id":1,"completion":"SELECT 1"}') + outputs_path.write_text('{"question_id":1,"completion":"SELECT 1"}\n') questions_path = tmp_path / "questions.jsonl" - questions_path.write_text('{"question_id":1,"table_id":1,"question":"x","sql":"SELECT 1"}') + questions_path.write_text( + '{"question_id":1,"table_id":1,"question":"x","sql":"SELECT 1"}\n' + ) - monkeypatch.setattr("llmsql.utils.evaluation_utils.evaluate_sample", lambda *a, **k: (1, None, {})) - monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) - monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) + db_path = tmp_path / "dummy.db" + sqlite3.connect(db_path).close() - kwargs = { - "outputs": str(outputs_path), - "questions_path": str(questions_path), - "db_path": tmp_path / "dummy.db", - "show_mismatches": False, - "version": INVALID_LLMSQL_VERSION, - } + monkeypatch.setattr( + "llmsql.evaluation.evaluate.evaluate_sample", + lambda *a, **k: (1, None, {}) + ) + monkeypatch.setattr( + "llmsql.evaluation.evaluate.log_mismatch", + lambda **k: None + ) + monkeypatch.setattr( + "llmsql.evaluation.evaluate.print_summary", + lambda *a, **k: None + ) with pytest.raises(Exception): - evaluate(**kwargs) \ No newline at end of file + evaluate( + outputs=str(outputs_path), + questions_path=str(questions_path), + db_path=str(db_path), + show_mismatches=False, + version=INVALID_LLMSQL_VERSION, + ) \ No newline at end of file From 8aca67c92a7b1e147bcbb7b7223b65a01f5f9af9 Mon Sep 17 
00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Fri, 20 Feb 2026 16:57:36 +0100 Subject: [PATCH 11/16] fixed --- .../test_evaluator_different_llmsql_versions.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/evaluation/test_evaluator_different_llmsql_versions.py b/tests/evaluation/test_evaluator_different_llmsql_versions.py index 66a4119..e57a355 100644 --- a/tests/evaluation/test_evaluator_different_llmsql_versions.py +++ b/tests/evaluation/test_evaluator_different_llmsql_versions.py @@ -26,9 +26,17 @@ def test_evaluate_runs_with_valid_versions(monkeypatch, tmp_path, version_arg): # --- Patch heavy evaluation internals --- monkeypatch.setattr( - "llmsql.evaluation.evaluate.evaluate_sample", - lambda *a, **k: (1, None, {}) - ) + "llmsql.evaluation.evaluate.evaluate_sample", + lambda *a, **k: ( + 1, + None, + { + "pred_none": 0, + "gold_none": 0, + "sql_errors": 0, + }, + ), +) monkeypatch.setattr( "llmsql.evaluation.evaluate.log_mismatch", lambda **k: None From 8b6b333f98f07da0a0c642994761bcabb2c30dcb Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Fri, 20 Feb 2026 17:06:00 +0100 Subject: [PATCH 12/16] another fixing attempt --- ...est_evaluator_different_llmsql_versions.py | 112 ++++++++++-------- 1 file changed, 61 insertions(+), 51 deletions(-) diff --git a/tests/evaluation/test_evaluator_different_llmsql_versions.py b/tests/evaluation/test_evaluator_different_llmsql_versions.py index e57a355..a10fe50 100644 --- a/tests/evaluation/test_evaluator_different_llmsql_versions.py +++ b/tests/evaluation/test_evaluator_different_llmsql_versions.py @@ -1,4 +1,4 @@ -import sqlite3 +import json import pytest from llmsql import evaluate @@ -6,50 +6,49 @@ VALID_LLMSQL_VERSIONS = [None] + get_available_versions() -INVALID_LLMSQL_VERSION = "999.0" +INVALID_LLMSQL_VERSION = "1.1" +@pytest.mark.asyncio @pytest.mark.parametrize("version_arg", 
VALID_LLMSQL_VERSIONS) -def test_evaluate_runs_with_valid_versions(monkeypatch, tmp_path, version_arg): - # --- Minimal fake outputs and questions --- - outputs_path = tmp_path / "outputs.jsonl" - outputs_path.write_text('{"question_id":1,"completion":"SELECT 1"}\n') - - questions_path = tmp_path / "questions.jsonl" +async def test_evaluate_runs_with_valid_versions( + monkeypatch, temp_dir, dummy_db_file, version_arg +): + # Fake questions.jsonl + questions_path = temp_dir / "questions.jsonl" questions_path.write_text( - '{"question_id":1,"table_id":1,"question":"x","sql":"SELECT 1"}\n' + json.dumps( + { + "question_id": 1, + "table_id": 1, + "question": "Sample", + "sql": "SELECT 1", + } + ) ) - # --- Create a real (empty) SQLite database file --- - db_path = tmp_path / "dummy.db" - sqlite3.connect(db_path).close() - - # --- Patch heavy evaluation internals --- - monkeypatch.setattr( - "llmsql.evaluation.evaluate.evaluate_sample", - lambda *a, **k: ( - 1, - None, - { - "pred_none": 0, - "gold_none": 0, - "sql_errors": 0, - }, - ), -) - monkeypatch.setattr( - "llmsql.evaluation.evaluate.log_mismatch", - lambda **k: None + # Fake outputs.jsonl + outputs_path = temp_dir / "outputs.jsonl" + outputs_path.write_text( + json.dumps({"question_id": 1, "completion": "SELECT 1"}) ) + + # Monkeypatch exactly like reference file monkeypatch.setattr( - "llmsql.evaluation.evaluate.print_summary", - lambda *a, **k: None + "llmsql.utils.evaluation_utils.evaluate_sample", + lambda *a, **k: ( + 1, + None, + {"pred_none": 0, "gold_none": 0, "sql_error": 0}, + ), ) + monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) + monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) kwargs = { "outputs": str(outputs_path), "questions_path": str(questions_path), - "db_path": str(db_path), + "db_path": dummy_db_file, "show_mismatches": False, } @@ -57,39 +56,50 @@ def test_evaluate_runs_with_valid_versions(monkeypatch, tmp_path, version_arg): 
kwargs["version"] = version_arg # Should NOT raise - evaluate(**kwargs) + report = evaluate(**kwargs) + # Basic sanity like reference tests + assert report["total"] == 1 + assert report["matches"] == 1 -def test_evaluate_raises_with_invalid_version(monkeypatch, tmp_path): - outputs_path = tmp_path / "outputs.jsonl" - outputs_path.write_text('{"question_id":1,"completion":"SELECT 1"}\n') - questions_path = tmp_path / "questions.jsonl" +@pytest.mark.asyncio +async def test_evaluate_raises_with_invalid_version( + monkeypatch, temp_dir, dummy_db_file +): + questions_path = temp_dir / "questions.jsonl" questions_path.write_text( - '{"question_id":1,"table_id":1,"question":"x","sql":"SELECT 1"}\n' + json.dumps( + { + "question_id": 1, + "table_id": 1, + "question": "Sample", + "sql": "SELECT 1", + } + ) ) - db_path = tmp_path / "dummy.db" - sqlite3.connect(db_path).close() - - monkeypatch.setattr( - "llmsql.evaluation.evaluate.evaluate_sample", - lambda *a, **k: (1, None, {}) - ) - monkeypatch.setattr( - "llmsql.evaluation.evaluate.log_mismatch", - lambda **k: None + outputs_path = temp_dir / "outputs.jsonl" + outputs_path.write_text( + json.dumps({"question_id": 1, "completion": "SELECT 1"}) ) + monkeypatch.setattr( - "llmsql.evaluation.evaluate.print_summary", - lambda *a, **k: None + "llmsql.utils.evaluation_utils.evaluate_sample", + lambda *a, **k: ( + 1, + None, + {"pred_none": 0, "gold_none": 0, "sql_error": 0}, + ), ) + monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None) + monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None) with pytest.raises(Exception): evaluate( outputs=str(outputs_path), questions_path=str(questions_path), - db_path=str(db_path), + db_path=dummy_db_file, show_mismatches=False, version=INVALID_LLMSQL_VERSION, ) \ No newline at end of file From 6adaf73fc0d3987609d62d324c40a5b6759e73ce Mon Sep 17 00:00:00 2001 From: Karol Charchut <59970980+Quarol@users.noreply.github.com> Date: Fri, 20 Feb 
2026 17:15:40 +0100 Subject: [PATCH 13/16] removed redudant assertions --- .../evaluation/test_evaluator_different_llmsql_versions.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/evaluation/test_evaluator_different_llmsql_versions.py b/tests/evaluation/test_evaluator_different_llmsql_versions.py index a10fe50..8fce11a 100644 --- a/tests/evaluation/test_evaluator_different_llmsql_versions.py +++ b/tests/evaluation/test_evaluator_different_llmsql_versions.py @@ -55,12 +55,8 @@ async def test_evaluate_runs_with_valid_versions( if version_arg is not None: kwargs["version"] = version_arg - # Should NOT raise - report = evaluate(**kwargs) + evaluate(**kwargs) - # Basic sanity like reference tests - assert report["total"] == 1 - assert report["matches"] == 1 @pytest.mark.asyncio From 41c7972f7b084f720deba7bf515c9e29c2b0c862 Mon Sep 17 00:00:00 2001 From: Dzmitry Pihulski Date: Mon, 23 Feb 2026 08:28:28 +0100 Subject: [PATCH 14/16] fix: monkeypatch path updated; --- .gitignore | 1 + llmsql/config/config.py | 15 ++++++-- ...nference_vllm_different_llmsql_versions.py | 37 ++++++++++++++----- 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index 6b09728..a3c3066 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ dist/ llmsql_workdir evaluation_* +coverage.xml diff --git a/llmsql/config/config.py b/llmsql/config/config.py index d8f9e36..9ba9a1c 100644 --- a/llmsql/config/config.py +++ b/llmsql/config/config.py @@ -1,12 +1,19 @@ -REPO_IDs: dict = { +REPO_IDs: dict[str, str] = { "1.0": "llmsql-bench/llmsql-benchmark", - "2.0": "llmsql-bench/llmsql-2.0" + "2.0": "llmsql-bench/llmsql-2.0", } DEFAULT_LLMSQL_VERSION = "2.0" DEFAULT_WORKDIR_PATH = "llmsql_workdir" + def get_repo_id(version: str = DEFAULT_LLMSQL_VERSION) -> str: - return REPO_IDs[version] + try: + return REPO_IDs[version] + except KeyError as err: + raise ValueError( + f"version should be one of: {list(REPO_IDs.keys())}, not {version}" + ) 
from err + def get_available_versions() -> list[str]: - return list(REPO_IDs.keys()) \ No newline at end of file + return list(REPO_IDs.keys()) diff --git a/tests/inference/test_inference_vllm_different_llmsql_versions.py b/tests/inference/test_inference_vllm_different_llmsql_versions.py index 78615c2..bc70af7 100644 --- a/tests/inference/test_inference_vllm_different_llmsql_versions.py +++ b/tests/inference/test_inference_vllm_different_llmsql_versions.py @@ -1,14 +1,20 @@ import json from pathlib import Path +import re from unittest.mock import MagicMock + import pytest +from llmsql.config.config import REPO_IDs, get_available_versions import llmsql.inference.inference_vllm as mod -from llmsql.config.config import get_available_versions questions = [ {"question_id": "q1", "table_id": "t1", "question": "Select name from students;"}, - {"question_id": "q2", "table_id": "t1", "question": "Count students older than 20;"}, + { + "question_id": "q2", + "table_id": "t1", + "question": "Count students older than 20;", + }, ] tables = [ { @@ -34,10 +40,13 @@ async def test_inference_vllm_valid_versions(monkeypatch, tmp_path, version_arg) q_file.write_text("\n".join(json.dumps(q) for q in questions)) t_file.write_text("\n".join(json.dumps(t) for t in tables)) - monkeypatch.setattr(mod, "load_jsonl", - lambda path: [json.loads(line) for line in Path(path).read_text().splitlines()]) - monkeypatch.setattr(mod, "overwrite_jsonl", lambda path: None) - monkeypatch.setattr(mod, "save_jsonl_lines", lambda path, lines: None) + monkeypatch.setattr( + mod, + "load_jsonl", + lambda path: [json.loads(line) for line in Path(path).read_text().splitlines()], + ) + monkeypatch.setattr(mod, "overwrite_jsonl", lambda path: Path(path).touch()) + monkeypatch.setattr(mod, "save_jsonl_lines", lambda path, lines: Path(path).touch()) monkeypatch.setattr(mod, "choose_prompt_builder", lambda shots: lambda *a: "PROMPT") fake_llm = MagicMock() @@ -74,8 +83,11 @@ async def 
test_inference_vllm_invalid_version(monkeypatch, tmp_path): q_file.write_text("\n".join(json.dumps(q) for q in questions)) t_file.write_text("\n".join(json.dumps(t) for t in tables)) - monkeypatch.setattr(mod, "load_jsonl", - lambda path: [json.loads(line) for line in Path(path).read_text().splitlines()]) + monkeypatch.setattr( + mod, + "load_jsonl", + lambda path: [json.loads(line) for line in Path(path).read_text().splitlines()], + ) monkeypatch.setattr(mod, "overwrite_jsonl", lambda path: None) monkeypatch.setattr(mod, "save_jsonl_lines", lambda path, lines: None) monkeypatch.setattr(mod, "choose_prompt_builder", lambda shots: lambda *a: "PROMPT") @@ -96,5 +108,10 @@ async def test_inference_vllm_invalid_version(monkeypatch, tmp_path): "version": INVALID_LLMSQL_VERSION, # invalid version } - with pytest.raises(Exception): - mod.inference_vllm(**kwargs) \ No newline at end of file + expected = ( + f"version should be one of: {list(REPO_IDs.keys())}, " + f"not {INVALID_LLMSQL_VERSION}" + ) + + with pytest.raises(ValueError, match=re.escape(expected)): + mod.inference_vllm(**kwargs) From ad970d091982c3a19c8c390b5c7a8364f4aa9601 Mon Sep 17 00:00:00 2001 From: Dzmitry Pihulski Date: Mon, 23 Feb 2026 10:51:03 +0100 Subject: [PATCH 15/16] add: limit flag --- examples/inference_transformers.ipynb | 2 +- .../inference_transformers_version_1.0.ipynb | 85 ++++ llmsql/config/config.py | 5 +- llmsql/inference/inference_transformers.py | 45 ++- llmsql/inference/inference_vllm.py | 47 ++- llmsql/utils/inference_utils.py | 34 +- pdm.lock | 372 +++++++++++++++++- pyproject.toml | 2 + tests/inference/test_limit_argument.py | 270 +++++++++++++ tests/utils/test_inference_utils.py | 35 +- 10 files changed, 813 insertions(+), 84 deletions(-) create mode 100644 examples/inference_transformers_version_1.0.ipynb create mode 100644 tests/inference/test_limit_argument.py diff --git a/examples/inference_transformers.ipynb b/examples/inference_transformers.ipynb index a59bfa7..43b6928 
100644 --- a/examples/inference_transformers.ipynb +++ b/examples/inference_transformers.ipynb @@ -87,7 +87,7 @@ ], "metadata": { "kernelspec": { - "display_name": "llmsql-benchmark-3.11", + "display_name": "llmsql-benchmark-3.11 (3.11.13)", "language": "python", "name": "python3" }, diff --git a/examples/inference_transformers_version_1.0.ipynb b/examples/inference_transformers_version_1.0.ipynb new file mode 100644 index 0000000..c67a90b --- /dev/null +++ b/examples/inference_transformers_version_1.0.ipynb @@ -0,0 +1,85 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e71979fc", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/pihul/Desktop/gdrive/Projects/llmsql-benchmark/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model from: EleutherAI/pythia-14m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`torch_dtype` is deprecated! 
Use `dtype` instead!\n", + "2026-02-23 09:00:00,424 [INFO] llmsql-bench: Removing existing path: llmsql_workdir/questions.jsonl\n", + "2026-02-23 09:00:00,426 [INFO] llmsql-bench: Downloading questions.jsonl from Hugging Face Hub...\n", + "2026-02-23 09:00:01,433 [INFO] llmsql-bench: Downloaded questions.jsonl to: llmsql_workdir/questions.jsonl\n", + "2026-02-23 09:00:01,434 [INFO] llmsql-bench: Removing existing path: llmsql_workdir/tables.jsonl\n", + "2026-02-23 09:00:01,436 [INFO] llmsql-bench: Downloading tables.jsonl from Hugging Face Hub...\n", + "2026-02-23 09:00:02,117 [INFO] llmsql-bench: Downloaded tables.jsonl to: llmsql_workdir/tables.jsonl\n", + "2026-02-23 09:00:03,373 [INFO] llmsql-bench: Limiting evaluation to first 100 questions out of 80330\n", + "2026-02-23 09:00:03,383 [INFO] llmsql-bench: Writing results to test_output.jsonl\n", + "2026-02-23 09:00:03,384 [INFO] llmsql-bench: Using 5-shot prompt builder: build_prompt_5shot\n", + "Generating: 0%| | 0/4 [00:00 list[dict[str, str]]: """ @@ -138,6 +143,9 @@ def inference_transformers( num_fewshots: Number of few-shot examples (0, 1, or 5). batch_size: Batch size for inference. seed: Random seed for reproducibility. + limit: Limit the number of questions to evaluate. If an integer, evaluates + the first N samples. If a float between 0.0 and 1.0, evaluates the + first X*100% of samples. If None, evaluates all samples (default). Returns: List of generated SQL results with metadata. 
@@ -212,22 +220,31 @@ def inference_transformers( # --- Load necessary files --- repo_id = get_repo_id(version) - - questions_path = _maybe_download( - repo_id, - "questions.jsonl", - questions_path - ) - tables_path = _maybe_download( - repo_id, - "tables.jsonl", - tables_path - ) + + questions_path = _maybe_download(repo_id, "questions.jsonl", questions_path) + tables_path = _maybe_download(repo_id, "tables.jsonl", tables_path) questions = load_jsonl(questions_path) tables_list = load_jsonl(tables_path) tables = {t["table_id"]: t for t in tables_list} + # --- Apply limit --- + if limit is not None: + if isinstance(limit, float): + if not (0.0 < limit <= 1.0): + raise ValueError( + f"When a float, `limit` must be between 0.0 and 1.0, got {limit}." + ) + limit = max(1, int(len(questions) * limit)) + if not isinstance(limit, int) or limit < 1: + raise ValueError( + f"`limit` must be a positive integer or a float in (0.0, 1.0], got {limit!r}." + ) + log.info( + f"Limiting evaluation to first {limit} questions out of {len(questions)}" + ) + questions = questions[:limit] + # --- Chat template setup --- use_chat_template = chat_template or getattr(tokenizer, "chat_template", None) if use_chat_template: diff --git a/llmsql/inference/inference_vllm.py b/llmsql/inference/inference_vllm.py index daf677a..f661c6b 100644 --- a/llmsql/inference/inference_vllm.py +++ b/llmsql/inference/inference_vllm.py @@ -42,13 +42,17 @@ os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" from pathlib import Path -from typing import Any +from typing import Any, Literal from dotenv import load_dotenv from tqdm import tqdm from vllm import LLM, SamplingParams -from llmsql.config.config import DEFAULT_WORKDIR_PATH, get_repo_id, DEFAULT_LLMSQL_VERSION +from llmsql.config.config import ( + DEFAULT_LLMSQL_VERSION, + DEFAULT_WORKDIR_PATH, + get_repo_id, +) from llmsql.loggers.logging_config import log from llmsql.utils.inference_utils import _maybe_download, _setup_seed from llmsql.utils.utils import 
( @@ -76,11 +80,12 @@ def inference_vllm( do_sample: bool = True, sampling_kwargs: dict[str, Any] | None = None, # === Benchmark Parameters === - version: str = DEFAULT_LLMSQL_VERSION, + version: Literal["1.0", "2.0"] = DEFAULT_LLMSQL_VERSION, output_file: str = "llm_sql_predictions.jsonl", questions_path: str | None = None, tables_path: str | None = None, workdir_path: str = DEFAULT_WORKDIR_PATH, + limit: int | float | None = None, num_fewshots: int = 5, batch_size: int = 8, seed: int = 42, @@ -117,6 +122,9 @@ def inference_vllm( num_fewshots: Number of few-shot examples (0, 1, or 5). batch_size: Number of questions per generation batch. seed: Random seed for reproducibility. + limit: Limit the number of questions to evaluate. If an integer, evaluates + the first N samples. If a float between 0.0 and 1.0, evaluates the + first X*100% of samples. If None, evaluates all samples (default). Returns: List of dicts containing `question_id` and generated `completion`. @@ -132,24 +140,33 @@ def inference_vllm( # --- load input data --- log.info("Preparing questions and tables...") - - repo_id = get_repo_id(version) - questions_path = _maybe_download( - repo_id, - "questions.jsonl", - questions_path - ) - tables_path = _maybe_download( - repo_id, - "tables.jsonl", - tables_path - ) + repo_id = get_repo_id(version) + + questions_path = _maybe_download(repo_id, "questions.jsonl", questions_path) + tables_path = _maybe_download(repo_id, "tables.jsonl", tables_path) questions = load_jsonl(questions_path) tables_list = load_jsonl(tables_path) tables = {t["table_id"]: t for t in tables_list} + # --- Apply limit --- + if limit is not None: + if isinstance(limit, float): + if not (0.0 < limit <= 1.0): + raise ValueError( + f"When a float, `limit` must be between 0.0 and 1.0, got {limit}." 
+ ) + limit = max(1, int(len(questions) * limit)) + if not isinstance(limit, int) or limit < 1: + raise ValueError( + f"`limit` must be a positive integer or a float in (0.0, 1.0], got {limit!r}." + ) + log.info( + f"Limiting evaluation to first {limit} questions out of {len(questions)}" + ) + questions = questions[:limit] + # --- init model --- llm_init_args = { "model": model_name, diff --git a/llmsql/utils/inference_utils.py b/llmsql/utils/inference_utils.py index 4189aad..892ccd9 100644 --- a/llmsql/utils/inference_utils.py +++ b/llmsql/utils/inference_utils.py @@ -29,22 +29,24 @@ def _setup_seed(seed: int) -> None: torch.cuda.manual_seed_all(seed) -def _maybe_download( - repo_id: str, - filename: str, - local_path: str | None -) -> str: +def _maybe_download(repo_id: str, filename: str, local_path: str | None) -> str: if local_path is not None: return local_path + target_path = Path(DEFAULT_WORKDIR_PATH) / filename - if not target_path.exists(): - log.info(f"Downloading {filename} from Hugging Face Hub...") - local_path = hf_hub_download( - repo_id=repo_id, - filename=filename, - repo_type="dataset", - local_dir=DEFAULT_WORKDIR_PATH, - ) - log.info(f"Downloaded {filename} to: {local_path}") - return local_path - return str(target_path) + + if target_path.exists(): + log.info(f"Removing existing path: {target_path}") + if target_path.is_file() or target_path.is_symlink(): + target_path.unlink() + + log.info(f"Downloading {filename} from Hugging Face Hub...") + local_path = hf_hub_download( + repo_id=repo_id, + filename=filename, + repo_type="dataset", + local_dir=DEFAULT_WORKDIR_PATH, + ) + log.info(f"Downloaded {filename} to: {local_path}") + + return local_path diff --git a/pdm.lock b/pdm.lock index 10b38e0..268d620 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "vllm"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:49866b58c866ccafa28447f8dd4323a85c2152ad2196b761929200362b921ad1" +content_hash 
= "sha256:b04bedef5f2891fa9001f370464e9fe61b569b3f0ab5a707a859df549df4b885" [[metadata.targets]] requires_python = ">=3.10" @@ -245,6 +245,18 @@ files = [ {file = "anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4"}, ] +[[package]] +name = "appnope" +version = "0.1.4" +requires_python = ">=3.6" +summary = "Disable App Nap on macOS >= 10.9" +groups = ["dev"] +marker = "platform_system == \"Darwin\"" +files = [ + {file = "appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c"}, + {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"}, +] + [[package]] name = "astor" version = "0.8.1" @@ -270,6 +282,17 @@ files = [ {file = "astroid-4.0.1.tar.gz", hash = "sha256:0d778ec0def05b935e198412e62f9bcca8b3b5c39fdbe50b0ba074005e477aab"}, ] +[[package]] +name = "asttokens" +version = "3.0.1" +requires_python = ">=3.8" +summary = "Annotate AST trees with source code positions" +groups = ["dev"] +files = [ + {file = "asttokens-3.0.1-py3-none-any.whl", hash = "sha256:15a3ebc0f43c2d0a50eeafea25e19046c68398e487b9f1f5b517f7c0f40f976a"}, + {file = "asttokens-3.0.1.tar.gz", hash = "sha256:71a4ee5de0bde6a31d64f6b13f2293ac190344478f081c3d1bccfcf5eacb0cb7"}, +] + [[package]] name = "async-timeout" version = "5.0.1" @@ -458,7 +481,7 @@ name = "cffi" version = "2.0.0" requires_python = ">=3.9" summary = "Foreign Function Interface for Python calling C code." -groups = ["vllm"] +groups = ["dev", "vllm"] marker = "implementation_name == \"pypy\"" dependencies = [ "pycparser; implementation_name != \"PyPy\"", @@ -676,6 +699,17 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "comm" +version = "0.2.3" +requires_python = ">=3.8" +summary = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." 
+groups = ["dev"] +files = [ + {file = "comm-0.2.3-py3-none-any.whl", hash = "sha256:c615d91d75f7f04f095b30d1c1711babd43bdc6419c1be9886a85f2f4e489417"}, + {file = "comm-0.2.3.tar.gz", hash = "sha256:2dc8048c10962d55d7ad693be1e7045d891b7ce8d999c97963a5e3e99c055971"}, +] + [[package]] name = "compressed-tensors" version = "0.9.2" @@ -951,6 +985,48 @@ files = [ {file = "datasets-4.2.0.tar.gz", hash = "sha256:8333a7db9f3bb8044c1b819a35d4e3e2809596c837793b0921382efffdc36e78"}, ] +[[package]] +name = "debugpy" +version = "1.8.20" +requires_python = ">=3.8" +summary = "An implementation of the Debug Adapter Protocol for Python" +groups = ["dev"] +files = [ + {file = "debugpy-1.8.20-cp310-cp310-macosx_15_0_x86_64.whl", hash = "sha256:157e96ffb7f80b3ad36d808646198c90acb46fdcfd8bb1999838f0b6f2b59c64"}, + {file = "debugpy-1.8.20-cp310-cp310-manylinux_2_34_x86_64.whl", hash = "sha256:c1178ae571aff42e61801a38b007af504ec8e05fde1c5c12e5a7efef21009642"}, + {file = "debugpy-1.8.20-cp310-cp310-win32.whl", hash = "sha256:c29dd9d656c0fbd77906a6e6a82ae4881514aa3294b94c903ff99303e789b4a2"}, + {file = "debugpy-1.8.20-cp310-cp310-win_amd64.whl", hash = "sha256:3ca85463f63b5dd0aa7aaa933d97cbc47c174896dcae8431695872969f981893"}, + {file = "debugpy-1.8.20-cp311-cp311-macosx_15_0_universal2.whl", hash = "sha256:eada6042ad88fa1571b74bd5402ee8b86eded7a8f7b827849761700aff171f1b"}, + {file = "debugpy-1.8.20-cp311-cp311-manylinux_2_34_x86_64.whl", hash = "sha256:7de0b7dfeedc504421032afba845ae2a7bcc32ddfb07dae2c3ca5442f821c344"}, + {file = "debugpy-1.8.20-cp311-cp311-win32.whl", hash = "sha256:773e839380cf459caf73cc533ea45ec2737a5cc184cf1b3b796cd4fd98504fec"}, + {file = "debugpy-1.8.20-cp311-cp311-win_amd64.whl", hash = "sha256:1f7650546e0eded1902d0f6af28f787fa1f1dbdbc97ddabaf1cd963a405930cb"}, + {file = "debugpy-1.8.20-cp312-cp312-macosx_15_0_universal2.whl", hash = "sha256:4ae3135e2089905a916909ef31922b2d733d756f66d87345b3e5e52b7a55f13d"}, + {file = 
"debugpy-1.8.20-cp312-cp312-manylinux_2_34_x86_64.whl", hash = "sha256:88f47850a4284b88bd2bfee1f26132147d5d504e4e86c22485dfa44b97e19b4b"}, + {file = "debugpy-1.8.20-cp312-cp312-win32.whl", hash = "sha256:4057ac68f892064e5f98209ab582abfee3b543fb55d2e87610ddc133a954d390"}, + {file = "debugpy-1.8.20-cp312-cp312-win_amd64.whl", hash = "sha256:a1a8f851e7cf171330679ef6997e9c579ef6dd33c9098458bd9986a0f4ca52e3"}, + {file = "debugpy-1.8.20-cp313-cp313-macosx_15_0_universal2.whl", hash = "sha256:5dff4bb27027821fdfcc9e8f87309a28988231165147c31730128b1c983e282a"}, + {file = "debugpy-1.8.20-cp313-cp313-manylinux_2_34_x86_64.whl", hash = "sha256:84562982dd7cf5ebebfdea667ca20a064e096099997b175fe204e86817f64eaf"}, + {file = "debugpy-1.8.20-cp313-cp313-win32.whl", hash = "sha256:da11dea6447b2cadbf8ce2bec59ecea87cc18d2c574980f643f2d2dfe4862393"}, + {file = "debugpy-1.8.20-cp313-cp313-win_amd64.whl", hash = "sha256:eb506e45943cab2efb7c6eafdd65b842f3ae779f020c82221f55aca9de135ed7"}, + {file = "debugpy-1.8.20-cp314-cp314-macosx_15_0_universal2.whl", hash = "sha256:9c74df62fc064cd5e5eaca1353a3ef5a5d50da5eb8058fcef63106f7bebe6173"}, + {file = "debugpy-1.8.20-cp314-cp314-manylinux_2_34_x86_64.whl", hash = "sha256:077a7447589ee9bc1ff0cdf443566d0ecf540ac8aa7333b775ebcb8ce9f4ecad"}, + {file = "debugpy-1.8.20-cp314-cp314-win32.whl", hash = "sha256:352036a99dd35053b37b7803f748efc456076f929c6a895556932eaf2d23b07f"}, + {file = "debugpy-1.8.20-cp314-cp314-win_amd64.whl", hash = "sha256:a98eec61135465b062846112e5ecf2eebb855305acc1dfbae43b72903b8ab5be"}, + {file = "debugpy-1.8.20-py2.py3-none-any.whl", hash = "sha256:5be9bed9ae3be00665a06acaa48f8329d2b9632f15fd09f6a9a8c8d9907e54d7"}, + {file = "debugpy-1.8.20.tar.gz", hash = "sha256:55bc8701714969f1ab89a6d5f2f3d40c36f91b2cbe2f65d98bf8196f6a6a2c33"}, +] + +[[package]] +name = "decorator" +version = "5.2.1" +requires_python = ">=3.8" +summary = "Decorators for Humans" +groups = ["dev"] +files = [ + {file = "decorator-5.2.1-py3-none-any.whl", hash = 
"sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a"}, + {file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"}, +] + [[package]] name = "depyf" version = "0.18.0" @@ -1072,6 +1148,17 @@ files = [ {file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"}, ] +[[package]] +name = "executing" +version = "2.2.1" +requires_python = ">=3.8" +summary = "Get the currently executing AST node of a frame, and other information" +groups = ["dev"] +files = [ + {file = "executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017"}, + {file = "executing-2.2.1.tar.gz", hash = "sha256:3632cc370565f6648cc328b32435bd120a1e4ebb20c77e3fdde9a13cd1e533c4"}, +] + [[package]] name = "fastapi" version = "0.119.0" @@ -1606,6 +1693,88 @@ files = [ {file = "interegular-0.3.3.tar.gz", hash = "sha256:d9b697b21b34884711399ba0f0376914b81899ce670032486d0d048344a76600"}, ] +[[package]] +name = "ipykernel" +version = "7.2.0" +requires_python = ">=3.10" +summary = "IPython Kernel for Jupyter" +groups = ["dev"] +dependencies = [ + "appnope>=0.1.2; platform_system == \"Darwin\"", + "comm>=0.1.1", + "debugpy>=1.6.5", + "ipython>=7.23.1", + "jupyter-client>=8.8.0", + "jupyter-core!=6.0.*,>=5.1", + "matplotlib-inline>=0.1", + "nest-asyncio>=1.4", + "packaging>=22", + "psutil>=5.7", + "pyzmq>=25", + "tornado>=6.4.1", + "traitlets>=5.4.0", +] +files = [ + {file = "ipykernel-7.2.0-py3-none-any.whl", hash = "sha256:3bbd4420d2b3cc105cbdf3756bfc04500b1e52f090a90716851f3916c62e1661"}, + {file = "ipykernel-7.2.0.tar.gz", hash = "sha256:18ed160b6dee2cbb16e5f3575858bc19d8f1fe6046a9a680c708494ce31d909e"}, +] + +[[package]] +name = "ipython" +version = "8.38.0" +requires_python = ">=3.10" +summary = "IPython: Productive Interactive Computing" +groups = ["dev"] +dependencies = [ + "colorama; sys_platform == \"win32\"", 
+ "decorator", + "exceptiongroup; python_version < \"3.11\"", + "jedi>=0.16", + "matplotlib-inline", + "pexpect>4.3; sys_platform != \"win32\" and sys_platform != \"emscripten\"", + "prompt-toolkit<3.1.0,>=3.0.41", + "pygments>=2.4.0", + "stack-data", + "traitlets>=5.13.0", + "typing-extensions>=4.6; python_version < \"3.12\"", +] +files = [ + {file = "ipython-8.38.0-py3-none-any.whl", hash = "sha256:750162629d800ac65bb3b543a14e7a74b0e88063eac9b92124d4b2aa3f6d8e86"}, + {file = "ipython-8.38.0.tar.gz", hash = "sha256:9cfea8c903ce0867cc2f23199ed8545eb741f3a69420bfcf3743ad1cec856d39"}, +] + +[[package]] +name = "ipywidgets" +version = "8.1.8" +requires_python = ">=3.7" +summary = "Jupyter interactive widgets" +groups = ["dev"] +dependencies = [ + "comm>=0.1.3", + "ipython>=6.1.0", + "jupyterlab-widgets~=3.0.15", + "traitlets>=4.3.1", + "widgetsnbextension~=4.0.14", +] +files = [ + {file = "ipywidgets-8.1.8-py3-none-any.whl", hash = "sha256:ecaca67aed704a338f88f67b1181b58f821ab5dc89c1f0f5ef99db43c1c2921e"}, + {file = "ipywidgets-8.1.8.tar.gz", hash = "sha256:61f969306b95f85fba6b6986b7fe45d73124d1d9e3023a8068710d47a22ea668"}, +] + +[[package]] +name = "jedi" +version = "0.19.2" +requires_python = ">=3.6" +summary = "An autocompletion tool for Python that can be used for text editors." 
+groups = ["dev"] +dependencies = [ + "parso<0.9.0,>=0.8.4", +] +files = [ + {file = "jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9"}, + {file = "jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0"}, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -1725,6 +1894,50 @@ files = [ {file = "jsonschema_specifications-2025.9.1.tar.gz", hash = "sha256:b540987f239e745613c7a9176f3edb72b832a4ac465cf02712288397832b5e8d"}, ] +[[package]] +name = "jupyter-client" +version = "8.8.0" +requires_python = ">=3.10" +summary = "Jupyter protocol implementation and client libraries" +groups = ["dev"] +dependencies = [ + "jupyter-core>=5.1", + "python-dateutil>=2.8.2", + "pyzmq>=25.0", + "tornado>=6.4.1", + "traitlets>=5.3", +] +files = [ + {file = "jupyter_client-8.8.0-py3-none-any.whl", hash = "sha256:f93a5b99c5e23a507b773d3a1136bd6e16c67883ccdbd9a829b0bbdb98cd7d7a"}, + {file = "jupyter_client-8.8.0.tar.gz", hash = "sha256:d556811419a4f2d96c869af34e854e3f059b7cc2d6d01a9cd9c85c267691be3e"}, +] + +[[package]] +name = "jupyter-core" +version = "5.9.1" +requires_python = ">=3.10" +summary = "Jupyter core package. A base package on which Jupyter projects rely." 
+groups = ["dev"] +dependencies = [ + "platformdirs>=2.5", + "traitlets>=5.3", +] +files = [ + {file = "jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407"}, + {file = "jupyter_core-5.9.1.tar.gz", hash = "sha256:4d09aaff303b9566c3ce657f580bd089ff5c91f5f89cf7d8846c3cdf465b5508"}, +] + +[[package]] +name = "jupyterlab-widgets" +version = "3.0.16" +requires_python = ">=3.7" +summary = "Jupyter interactive widgets for JupyterLab" +groups = ["dev"] +files = [ + {file = "jupyterlab_widgets-3.0.16-py3-none-any.whl", hash = "sha256:45fa36d9c6422cf2559198e4db481aa243c7a32d9926b500781c830c80f7ecf8"}, + {file = "jupyterlab_widgets-3.0.16.tar.gz", hash = "sha256:423da05071d55cf27a9e602216d35a3a65a3e41cdf9c5d3b643b814ce38c19e0"}, +] + [[package]] name = "lark" version = "1.2.2" @@ -1903,6 +2116,20 @@ files = [ {file = "markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698"}, ] +[[package]] +name = "matplotlib-inline" +version = "0.2.1" +requires_python = ">=3.9" +summary = "Inline Matplotlib backend for Jupyter" +groups = ["dev"] +dependencies = [ + "traitlets", +] +files = [ + {file = "matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76"}, + {file = "matplotlib_inline-0.2.1.tar.gz", hash = "sha256:e1ee949c340d771fc39e241ea75683deb94762c8fa5f2927ec57c83c4dffa9fe"}, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -2296,7 +2523,7 @@ name = "nest-asyncio" version = "1.6.0" requires_python = ">=3.5" summary = "Patch asyncio to allow nested event loops" -groups = ["vllm"] +groups = ["dev", "vllm"] files = [ {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, @@ -2809,6 +3036,17 @@ files = [ {file 
= "pandas-2.3.3.tar.gz", hash = "sha256:e05e1af93b977f7eafa636d043f9f94c7ee3ac81af99c13508215942e64c993b"}, ] +[[package]] +name = "parso" +version = "0.8.6" +requires_python = ">=3.6" +summary = "A Python Parser" +groups = ["dev"] +files = [ + {file = "parso-0.8.6-py2.py3-none-any.whl", hash = "sha256:2c549f800b70a5c4952197248825584cb00f033b29c692671d3bf08bf380baff"}, + {file = "parso-0.8.6.tar.gz", hash = "sha256:2b9a0332696df97d454fa67b81618fd69c35a7b90327cbe6ba5c92d2c68a7bfd"}, +] + [[package]] name = "partial-json-parser" version = "0.2.1.1.post6" @@ -2831,6 +3069,20 @@ files = [ {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, ] +[[package]] +name = "pexpect" +version = "4.9.0" +summary = "Pexpect allows easy control of interactive console applications." +groups = ["dev"] +marker = "sys_platform != \"win32\" and sys_platform != \"emscripten\"" +dependencies = [ + "ptyprocess>=0.5", +] +files = [ + {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"}, + {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"}, +] + [[package]] name = "pillow" version = "12.0.0" @@ -2997,6 +3249,20 @@ files = [ {file = "prometheus_fastapi_instrumentator-7.1.0.tar.gz", hash = "sha256:be7cd61eeea4e5912aeccb4261c6631b3f227d8924542d79eaf5af3f439cbe5e"}, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.52" +requires_python = ">=3.8" +summary = "Library for building powerful interactive command lines in Python" +groups = ["dev"] +dependencies = [ + "wcwidth", +] +files = [ + {file = "prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955"}, + {file = "prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855"}, +] + [[package]] name = "propcache" version = "0.4.1" 
@@ -3135,7 +3401,7 @@ name = "psutil" version = "7.1.0" requires_python = ">=3.6" summary = "Cross-platform lib for process and system monitoring." -groups = ["default", "vllm"] +groups = ["default", "dev", "vllm"] files = [ {file = "psutil-7.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:76168cef4397494250e9f4e73eb3752b146de1dd950040b29186d0cce1d5ca13"}, {file = "psutil-7.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:5d007560c8c372efdff9e4579c2846d71de737e4605f611437255e81efcca2c5"}, @@ -3148,6 +3414,27 @@ files = [ {file = "psutil-7.1.0.tar.gz", hash = "sha256:655708b3c069387c8b77b072fc429a57d0e214221d01c0a772df7dfedcb3bcd2"}, ] +[[package]] +name = "ptyprocess" +version = "0.7.0" +summary = "Run a subprocess in a pseudo terminal" +groups = ["dev"] +marker = "sys_platform != \"win32\" and sys_platform != \"emscripten\"" +files = [ + {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, + {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +summary = "Safely evaluate AST nodes without side effects" +groups = ["dev"] +files = [ + {file = "pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0"}, + {file = "pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42"}, +] + [[package]] name = "py-cpuinfo" version = "9.0.0" @@ -3222,7 +3509,7 @@ name = "pycparser" version = "2.23" requires_python = ">=3.8" summary = "C parser in Python" -groups = ["vllm"] +groups = ["dev", "vllm"] marker = "implementation_name == \"pypy\"" files = [ {file = "pycparser-2.23-py3-none-any.whl", hash = "sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934"}, @@ -3452,7 +3739,7 @@ name = "python-dateutil" version = "2.9.0.post0" requires_python = 
"!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" summary = "Extensions to the standard Python datetime module" -groups = ["default"] +groups = ["default", "dev"] dependencies = [ "six>=1.5", ] @@ -3578,7 +3865,7 @@ name = "pyzmq" version = "27.1.0" requires_python = ">=3.8" summary = "Python bindings for 0MQ" -groups = ["vllm"] +groups = ["dev", "vllm"] dependencies = [ "cffi; implementation_name == \"pypy\"", ] @@ -4342,7 +4629,7 @@ name = "six" version = "1.17.0" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" summary = "Python 2 and 3 compatibility utilities" -groups = ["default", "vllm"] +groups = ["default", "dev", "vllm"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -4557,6 +4844,21 @@ files = [ {file = "sphinxcontrib_serializinghtml-2.0.0.tar.gz", hash = "sha256:e9d912827f872c029017a53f0ef2180b327c3f7fd23c87229f7a8e8b70031d4d"}, ] +[[package]] +name = "stack-data" +version = "0.6.3" +summary = "Extract data from python stack frames and tracebacks for informative displays" +groups = ["dev"] +dependencies = [ + "asttokens>=2.1.0", + "executing>=1.2.0", + "pure-eval", +] +files = [ + {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, + {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, +] + [[package]] name = "starlette" version = "0.48.0" @@ -4838,6 +5140,27 @@ files = [ {file = "torchvision-0.21.0-cp313-cp313-win_amd64.whl", hash = "sha256:9147f5e096a9270684e3befdee350f3cacafd48e0c54ab195f45790a9c146d67"}, ] +[[package]] +name = "tornado" +version = "6.5.4" +requires_python = ">=3.9" +summary = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." 
+groups = ["dev"] +files = [ + {file = "tornado-6.5.4-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d6241c1a16b1c9e4cc28148b1cda97dd1c6cb4fb7068ac1bedc610768dff0ba9"}, + {file = "tornado-6.5.4-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2d50f63dda1d2cac3ae1fa23d254e16b5e38153758470e9956cbc3d813d40843"}, + {file = "tornado-6.5.4-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1cf66105dc6acb5af613c054955b8137e34a03698aa53272dbda4afe252be17"}, + {file = "tornado-6.5.4-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50ff0a58b0dc97939d29da29cd624da010e7f804746621c78d14b80238669335"}, + {file = "tornado-6.5.4-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5fb5e04efa54cf0baabdd10061eb4148e0be137166146fff835745f59ab9f7f"}, + {file = "tornado-6.5.4-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9c86b1643b33a4cd415f8d0fe53045f913bf07b4a3ef646b735a6a86047dda84"}, + {file = "tornado-6.5.4-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:6eb82872335a53dd063a4f10917b3efd28270b56a33db69009606a0312660a6f"}, + {file = "tornado-6.5.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6076d5dda368c9328ff41ab5d9dd3608e695e8225d1cd0fd1e006f05da3635a8"}, + {file = "tornado-6.5.4-cp39-abi3-win32.whl", hash = "sha256:1768110f2411d5cd281bac0a090f707223ce77fd110424361092859e089b38d1"}, + {file = "tornado-6.5.4-cp39-abi3-win_amd64.whl", hash = "sha256:fa07d31e0cd85c60713f2b995da613588aa03e1303d75705dca6af8babc18ddc"}, + {file = "tornado-6.5.4-cp39-abi3-win_arm64.whl", hash = "sha256:053e6e16701eb6cbe641f308f4c1a9541f91b6261991160391bfc342e8a551a1"}, + {file = "tornado-6.5.4.tar.gz", hash = "sha256:a22fa9047405d03260b483980635f0b041989d8bcc9a313f8fe18b411d84b1d7"}, +] + [[package]] name = "tqdm" version = "4.67.1" @@ -4852,6 +5175,17 @@ files = [ {file = "tqdm-4.67.1.tar.gz", hash = 
"sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, ] +[[package]] +name = "traitlets" +version = "5.14.3" +requires_python = ">=3.8" +summary = "Traitlets Python configuration system" +groups = ["dev"] +files = [ + {file = "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, + {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, +] + [[package]] name = "transformers" version = "4.57.1" @@ -5252,6 +5586,17 @@ files = [ {file = "watchfiles-1.1.1.tar.gz", hash = "sha256:a173cb5c16c4f40ab19cecf48a534c409f7ea983ab8fed0741304a1c0a31b3f2"}, ] +[[package]] +name = "wcwidth" +version = "0.6.0" +requires_python = ">=3.8" +summary = "Measures the displayed width of unicode strings in a terminal" +groups = ["dev"] +files = [ + {file = "wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad"}, + {file = "wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159"}, +] + [[package]] name = "websockets" version = "15.0.1" @@ -5313,6 +5658,17 @@ files = [ {file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"}, ] +[[package]] +name = "widgetsnbextension" +version = "4.0.15" +requires_python = ">=3.7" +summary = "Jupyter interactive widgets for Jupyter Notebook" +groups = ["dev"] +files = [ + {file = "widgetsnbextension-4.0.15-py3-none-any.whl", hash = "sha256:8156704e4346a571d9ce73b84bee86a29906c9abfd7223b7228a28899ccf3366"}, + {file = "widgetsnbextension-4.0.15.tar.gz", hash = "sha256:de8610639996f1567952d763a5a41af8af37f2575a41f9852a38f947eb82a3b9"}, +] + [[package]] name = "xformers" version = "0.0.29.post2" diff --git a/pyproject.toml b/pyproject.toml index c0d82a1..8b79e97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,8 @@ dev = [ 
"sphinx-autobuild>=2024.10.3", "sphinx-copybutton>=0.5.2", "pytest-mock>=3.15.1", + "ipykernel>=7.2.0", + "ipywidgets>=8.1.8", ] vllm = [ "vllm>=0.4.2", diff --git a/tests/inference/test_limit_argument.py b/tests/inference/test_limit_argument.py new file mode 100644 index 0000000..e2d2e09 --- /dev/null +++ b/tests/inference/test_limit_argument.py @@ -0,0 +1,270 @@ +import json +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +import llmsql.inference.inference_transformers as transformers_mod +import llmsql.inference.inference_vllm as vllm_mod + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +QUESTIONS = [ + {"question_id": f"q{i}", "question": f"Question {i}?", "table_id": "t1"} + for i in range(1, 6) # 5 questions total +] +TABLES = [{"table_id": "t1", "header": ["col"], "types": ["text"], "rows": [["foo"]]}] + + +def _write_jsonl(path, records): + path.write_text("\n".join(json.dumps(r) for r in records)) + + +def _patch_common_vllm(monkeypatch, tmp_path): + """Patch all vLLM module-level dependencies.""" + monkeypatch.setattr( + vllm_mod, + "load_jsonl", + lambda path: [json.loads(line) for line in Path(path).read_text().splitlines()], + ) + monkeypatch.setattr( + vllm_mod, "overwrite_jsonl", lambda path: Path(path).write_text("") + ) + monkeypatch.setattr( + vllm_mod, + "save_jsonl_lines", + lambda path, lines: Path(path) + .open("a") + .write("\n".join(json.dumps(line) for line in lines) + "\n"), + ) + monkeypatch.setattr( + vllm_mod, + "choose_prompt_builder", + lambda shots: lambda q, h, t, r: f"PROMPT: {q}", + ) + + fake_llm = MagicMock() + fake_llm.generate.side_effect = lambda prompts, _params: [ + MagicMock(outputs=[MagicMock(text=f"SELECT {i}")]) for i in range(len(prompts)) + ] + monkeypatch.setattr(vllm_mod, "LLM", lambda *a, **kw: fake_llm) + + +def _patch_common_transformers(monkeypatch, 
tmp_path): + """Patch all Transformers module-level dependencies.""" + monkeypatch.setattr( + transformers_mod, + "load_jsonl", + lambda path: [json.loads(line) for line in Path(path).read_text().splitlines()], + ) + monkeypatch.setattr( + transformers_mod, "overwrite_jsonl", lambda path: Path(path).write_text("") + ) + monkeypatch.setattr( + transformers_mod, + "save_jsonl_lines", + lambda path, lines: Path(path) + .open("a") + .write("\n".join(json.dumps(line) for line in lines) + "\n"), + ) + monkeypatch.setattr( + transformers_mod, + "choose_prompt_builder", + lambda shots: lambda q, h, t, r: f"PROMPT: {q}", + ) + + fake_tokenizer = MagicMock() + fake_tokenizer.pad_token = "" + fake_tokenizer.pad_token_id = 0 + fake_tokenizer.chat_template = None + fake_tokenizer.return_value = {"input_ids": MagicMock()} + + fake_model = MagicMock() + # generate returns tensors of shape (batch, input_len + new_tokens) + fake_model.device = "cpu" + fake_model.generate.side_effect = lambda **kw: [ + [0] * (len(ids) + 5) for ids in kw["input_ids"] + ] + + monkeypatch.setattr( + transformers_mod, + "AutoModelForCausalLM", + MagicMock(from_pretrained=MagicMock(return_value=fake_model)), + ) + monkeypatch.setattr( + transformers_mod, + "AutoTokenizer", + MagicMock(from_pretrained=MagicMock(return_value=fake_tokenizer)), + ) + + +# --------------------------------------------------------------------------- +# vLLM limit tests +# --------------------------------------------------------------------------- + + +class TestInferenceVllmLimit: + @pytest.mark.asyncio + async def test_limit_integer_restricts_results(self, monkeypatch, tmp_path): + """Integer limit returns only the first N results.""" + qpath, tpath = tmp_path / "questions.jsonl", tmp_path / "tables.jsonl" + _write_jsonl(qpath, QUESTIONS) + _write_jsonl(tpath, TABLES) + _patch_common_vllm(monkeypatch, tmp_path) + + results = vllm_mod.inference_vllm( + model_name="dummy", + output_file=str(tmp_path / "out.jsonl"), + 
questions_path=str(qpath), + tables_path=str(tpath), + limit=3, + ) + + assert len(results) == 3 + assert [r["question_id"] for r in results] == ["q1", "q2", "q3"] + + @pytest.mark.asyncio + async def test_limit_float_restricts_results(self, monkeypatch, tmp_path): + """Float limit of 0.4 on 5 questions returns first 2 (floor, min 1).""" + qpath, tpath = tmp_path / "questions.jsonl", tmp_path / "tables.jsonl" + _write_jsonl(qpath, QUESTIONS) + _write_jsonl(tpath, TABLES) + _patch_common_vllm(monkeypatch, tmp_path) + + results = vllm_mod.inference_vllm( + model_name="dummy", + output_file=str(tmp_path / "out.jsonl"), + questions_path=str(qpath), + tables_path=str(tpath), + limit=0.4, + ) + + assert len(results) == 2 + assert results[0]["question_id"] == "q1" + + @pytest.mark.asyncio + async def test_limit_none_uses_all_samples(self, monkeypatch, tmp_path): + """No limit evaluates all questions.""" + qpath, tpath = tmp_path / "questions.jsonl", tmp_path / "tables.jsonl" + _write_jsonl(qpath, QUESTIONS) + _write_jsonl(tpath, TABLES) + _patch_common_vllm(monkeypatch, tmp_path) + + results = vllm_mod.inference_vllm( + model_name="dummy", + output_file=str(tmp_path / "out.jsonl"), + questions_path=str(qpath), + tables_path=str(tpath), + limit=None, + ) + + assert len(results) == len(QUESTIONS) + + @pytest.mark.asyncio + async def test_limit_float_1_uses_all_samples(self, monkeypatch, tmp_path): + """Float limit of 1.0 evaluates all questions.""" + qpath, tpath = tmp_path / "questions.jsonl", tmp_path / "tables.jsonl" + _write_jsonl(qpath, QUESTIONS) + _write_jsonl(tpath, TABLES) + _patch_common_vllm(monkeypatch, tmp_path) + + results = vllm_mod.inference_vllm( + model_name="dummy", + output_file=str(tmp_path / "out.jsonl"), + questions_path=str(qpath), + tables_path=str(tpath), + limit=1.0, + ) + + assert len(results) == len(QUESTIONS) + + @pytest.mark.asyncio + async def test_limit_invalid_float_raises(self, monkeypatch, tmp_path): + """Float outside (0.0, 1.0] raises 
ValueError.""" + qpath, tpath = tmp_path / "questions.jsonl", tmp_path / "tables.jsonl" + _write_jsonl(qpath, QUESTIONS) + _write_jsonl(tpath, TABLES) + _patch_common_vllm(monkeypatch, tmp_path) + + with pytest.raises(ValueError, match="0.0 and 1.0"): + vllm_mod.inference_vllm( + model_name="dummy", + output_file=str(tmp_path / "out.jsonl"), + questions_path=str(qpath), + tables_path=str(tpath), + limit=1.5, + ) + + @pytest.mark.asyncio + async def test_limit_invalid_int_raises(self, monkeypatch, tmp_path): + """Non-positive integer raises ValueError.""" + qpath, tpath = tmp_path / "questions.jsonl", tmp_path / "tables.jsonl" + _write_jsonl(qpath, QUESTIONS) + _write_jsonl(tpath, TABLES) + _patch_common_vllm(monkeypatch, tmp_path) + + with pytest.raises(ValueError): + vllm_mod.inference_vllm( + model_name="dummy", + output_file=str(tmp_path / "out.jsonl"), + questions_path=str(qpath), + tables_path=str(tpath), + limit=0, + ) + + @pytest.mark.asyncio + async def test_limit_larger_than_dataset_uses_all(self, monkeypatch, tmp_path): + """Integer limit larger than dataset size returns all samples.""" + qpath, tpath = tmp_path / "questions.jsonl", tmp_path / "tables.jsonl" + _write_jsonl(qpath, QUESTIONS) + _write_jsonl(tpath, TABLES) + _patch_common_vllm(monkeypatch, tmp_path) + + results = vllm_mod.inference_vllm( + model_name="dummy", + output_file=str(tmp_path / "out.jsonl"), + questions_path=str(qpath), + tables_path=str(tpath), + limit=9999, + ) + + assert len(results) == len(QUESTIONS) + + +# --------------------------------------------------------------------------- +# Transformers limit tests (same cases, different backend) +# --------------------------------------------------------------------------- + + +class TestInferenceTransformersLimit: + def test_limit_invalid_float_raises(self, monkeypatch, tmp_path): + qpath, tpath = tmp_path / "questions.jsonl", tmp_path / "tables.jsonl" + _write_jsonl(qpath, QUESTIONS) + _write_jsonl(tpath, TABLES) + 
_patch_common_transformers(monkeypatch, tmp_path) + + with pytest.raises(ValueError, match="0.0 and 1.0"): + transformers_mod.inference_transformers( + model_or_model_name_or_path="dummy", + output_file=str(tmp_path / "out.jsonl"), + questions_path=str(qpath), + tables_path=str(tpath), + limit=2.0, + ) + + def test_limit_invalid_int_raises(self, monkeypatch, tmp_path): + qpath, tpath = tmp_path / "questions.jsonl", tmp_path / "tables.jsonl" + _write_jsonl(qpath, QUESTIONS) + _write_jsonl(tpath, TABLES) + _patch_common_transformers(monkeypatch, tmp_path) + + with pytest.raises(ValueError): + transformers_mod.inference_transformers( + model_or_model_name_or_path="dummy", + output_file=str(tmp_path / "out.jsonl"), + questions_path=str(qpath), + tables_path=str(tpath), + limit=-1, + ) diff --git a/tests/utils/test_inference_utils.py b/tests/utils/test_inference_utils.py index 4fd7a0e..4ca241a 100644 --- a/tests/utils/test_inference_utils.py +++ b/tests/utils/test_inference_utils.py @@ -6,7 +6,11 @@ import pytest import torch -from llmsql.config.config import DEFAULT_WORKDIR_PATH, get_repo_id, DEFAULT_LLMSQL_VERSION +from llmsql.config.config import ( + DEFAULT_LLMSQL_VERSION, + DEFAULT_WORKDIR_PATH, + get_repo_id, +) from llmsql.utils import inference_utils as mod @@ -45,31 +49,6 @@ async def test_setup_seed(monkeypatch): assert a1 == a2 -@pytest.mark.asyncio -async def test_maybe_download_existing_file(tmp_path, monkeypatch): - """_maybe_download returns existing path without calling hf_hub_download.""" - existing = tmp_path / "questions.jsonl" - existing.write_text("dummy") - monkeypatch.setattr(mod, "hf_hub_download", lambda *a, **kw: "FAIL") - # Should return local path directly - path = mod._maybe_download( - get_repo_id(DEFAULT_LLMSQL_VERSION), - "questions.jsonl", - local_path=str(existing) - ) - assert path == str(existing) - - # Should also return target_path if file exists in DEFAULT_WORKDIR_PATH - monkeypatch.setattr(mod, "hf_hub_download", lambda *a, **kw: 
"FAIL") - monkeypatch.setattr(mod, "DEFAULT_WORKDIR_PATH", str(tmp_path)) - path2 = mod._maybe_download( - get_repo_id(DEFAULT_LLMSQL_VERSION), - "questions.jsonl", - local_path=None - ) - assert Path(path2).exists() or path2.endswith("questions.jsonl") - - @pytest.mark.asyncio async def test_maybe_download_calls_hf_hub(monkeypatch, tmp_path): """_maybe_download downloads file if missing.""" @@ -87,9 +66,7 @@ def fake_hf_hub_download(**kwargs): monkeypatch.setattr(mod, "hf_hub_download", fake_hf_hub_download) path = mod._maybe_download( - get_repo_id(DEFAULT_LLMSQL_VERSION), - filename, - local_path=None + get_repo_id(DEFAULT_LLMSQL_VERSION), filename, local_path=None ) assert Path(path).exists() assert called["repo_id"] == get_repo_id(DEFAULT_LLMSQL_VERSION) From 9dc932a1f235f53ea79018daff7a0ca407c0b9fe Mon Sep 17 00:00:00 2001 From: Dzmitry Pihulski Date: Mon, 23 Feb 2026 15:10:53 +0100 Subject: [PATCH 16/16] fix: cli args parser added --- llmsql/__main__.py | 170 +----------------------------- llmsql/_cli/__init__.py | 7 ++ llmsql/_cli/evaluate.py | 0 llmsql/_cli/inference.py | 215 ++++++++++++++++++++++++++++++++++++++ llmsql/_cli/llmsql_cli.py | 62 +++++++++++ llmsql/_cli/subparsers.py | 24 +++++ tests/cli/test_cli.py | 140 +++++++++++++++++++++++++ tests/conftest.py | 13 ++- tests/test_main.py | 22 ++++ 9 files changed, 487 insertions(+), 166 deletions(-) create mode 100644 llmsql/_cli/__init__.py create mode 100644 llmsql/_cli/evaluate.py create mode 100644 llmsql/_cli/inference.py create mode 100644 llmsql/_cli/llmsql_cli.py create mode 100644 llmsql/_cli/subparsers.py create mode 100644 tests/cli/test_cli.py create mode 100644 tests/test_main.py diff --git a/llmsql/__main__.py b/llmsql/__main__.py index 30b5c80..c3e8513 100644 --- a/llmsql/__main__.py +++ b/llmsql/__main__.py @@ -1,171 +1,11 @@ -import argparse -import inspect -import json -from llmsql.config.config import get_available_versions, DEFAULT_LLMSQL_VERSION +from llmsql._cli import ParserCLI 
def main() -> None:
    """Run the ``llmsql`` command line: build the parser, parse the arguments, dispatch."""
    cli = ParserCLI()
    cli.execute(cli.parse_args())
def parse_limit(value: str) -> float | int:
    """Parse the ``--limit`` command-line value into an ``int`` or a ``float``.

    Integer literals (``"100"``) are returned as ``int``. Any other finite
    value accepted by ``float`` (``"0.25"``, ``"1e-1"``) is returned as
    ``float``.

    Args:
        value: Raw string handed over by argparse.

    Returns:
        The parsed number.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a finite number.
    """
    try:
        return int(value)
    except ValueError:
        pass  # not an integer literal; try the float spelling next

    try:
        parsed = float(value)
    except ValueError as err:
        raise argparse.ArgumentTypeError("limit must be int or float") from err

    # float() also accepts "nan" / "inf", which are not usable limits.
    if parsed != parsed or parsed in (float("inf"), float("-inf")):
        raise argparse.ArgumentTypeError("limit must be int or float")
    return parsed


class Inference(SubCommand):
    """``inference`` subcommand with one nested subcommand per backend.

    ``llmsql inference transformers ...`` and ``llmsql inference vllm ...``
    each get their own parser; the selected backend's handler is stored in
    the parsed namespace under ``func``.
    """

    def __init__(
        self, subparsers: argparse._SubParsersAction, *args: Any, **kwargs: Any
    ) -> None:
        """Register the ``inference`` parser and its backend parsers.

        Args:
            subparsers: Sub-parser action of the parent parser.
            *args: Accepted for signature compatibility; unused.
            **kwargs: Accepted for signature compatibility; unused.
        """
        self._parser = subparsers.add_parser(
            "inference",
            help="Run inference",
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )

        # A backend must always be chosen: `llmsql inference` alone is an error.
        inference_subparsers = self._parser.add_subparsers(dest="method", required=True)

        self._parser_transformers = inference_subparsers.add_parser(
            "transformers",
            help="Use HuggingFace Transformers backend",
        )
        self._parser_vllm = inference_subparsers.add_parser(
            "vllm",
            help="Use vLLM backend",
        )

        self._add_args()

        # Single place where each backend parser is bound to its handler.
        self._parser_transformers.set_defaults(func=self._execute_transformers)
        self._parser_vllm.set_defaults(func=self._execute_vllm)

    def _add_args(self) -> None:
        """Declare the options of the ``transformers`` and ``vllm`` parsers."""

        def add_common_benchmark_args(parser: argparse.ArgumentParser) -> None:
            """Add benchmark data, output and batching options shared by both backends."""
            parser.add_argument("--version", default="2.0", choices=["1.0", "2.0"])
            parser.add_argument("--output-file", default="llm_sql_predictions.jsonl")
            parser.add_argument("--questions-path")
            parser.add_argument("--tables-path")
            parser.add_argument("--workdir-path", default="./workdir")
            parser.add_argument("--num-fewshots", type=int, default=5)
            parser.add_argument("--batch-size", type=int, default=8)
            parser.add_argument("--seed", type=int, default=42)
            parser.add_argument("--limit", type=parse_limit)

        def add_common_generation_args(
            parser: argparse.ArgumentParser, default_temp: float, default_sample: bool
        ) -> None:
            """Add generation options whose defaults differ per backend."""
            parser.add_argument("--max-new-tokens", type=int, default=256)
            parser.add_argument("--temperature", type=float, default=default_temp)
            # BooleanOptionalAction also generates --no-do-sample, so a backend
            # whose default is True can still be switched off from the CLI.
            parser.add_argument(
                "--do-sample",
                action=argparse.BooleanOptionalAction,
                default=default_sample,
            )

        # =========================
        # TRANSFORMERS
        # =========================
        transformers = self._parser_transformers
        transformers.add_argument(
            "--model-or-model-name-or-path",
            required=True,
            help="HF model name or local path",
        )
        transformers.add_argument("--tokenizer-or-name")
        # Defaults to True; --no-trust-remote-code disables it.
        transformers.add_argument(
            "--trust-remote-code",
            action=argparse.BooleanOptionalAction,
            default=True,
        )
        transformers.add_argument("--dtype", default="float16")
        transformers.add_argument("--device-map", default="auto")
        transformers.add_argument("--hf-token")
        # type=json.loads: malformed JSON is reported as an argparse usage error.
        transformers.add_argument(
            "--model-kwargs",
            type=json.loads,
            help="JSON string for AutoModel kwargs",
        )
        transformers.add_argument(
            "--tokenizer-kwargs",
            type=json.loads,
            help="JSON string for tokenizer kwargs",
        )
        transformers.add_argument("--chat-template")

        # Generation
        add_common_generation_args(transformers, 0.0, False)
        transformers.add_argument("--top-p", type=float, default=1.0)
        transformers.add_argument("--top-k", type=int, default=50)
        transformers.add_argument(
            "--generation-kwargs",
            type=json.loads,
            help="JSON string for generate() kwargs",
        )

        add_common_benchmark_args(transformers)

        # =========================
        # vLLM
        # =========================
        vllm = self._parser_vllm
        vllm.add_argument(
            "--model-name",
            required=True,
            help="HF model name or path",
        )
        # Defaults to True; --no-trust-remote-code disables it.
        vllm.add_argument(
            "--trust-remote-code",
            action=argparse.BooleanOptionalAction,
            default=True,
        )
        vllm.add_argument(
            "--tensor-parallel-size",
            type=int,
            default=1,
        )
        vllm.add_argument("--hf-token")
        vllm.add_argument(
            "--llm-kwargs",
            type=json.loads,
            help="JSON string for vllm.LLM kwargs",
        )
        # Defaults to True; --no-use-chat-template disables it.
        vllm.add_argument(
            "--use-chat-template",
            action=argparse.BooleanOptionalAction,
            default=True,
        )

        add_common_generation_args(vllm, 1.0, True)
        vllm.add_argument(
            "--sampling-kwargs",
            type=json.loads,
            help="JSON string for SamplingParams kwargs",
        )

        add_common_benchmark_args(vllm)

    @staticmethod
    def _execute_transformers(args: argparse.Namespace) -> None:
        """Call ``llmsql.inference_transformers`` with the parsed options.

        Args:
            args: Namespace produced by the ``transformers`` parser.
        """
        # Imported lazily so the backend is only loaded when this command runs.
        from llmsql import inference_transformers

        inference_transformers(
            model_or_model_name_or_path=args.model_or_model_name_or_path,
            tokenizer_or_name=args.tokenizer_or_name,
            trust_remote_code=args.trust_remote_code,
            dtype=args.dtype,
            device_map=args.device_map,
            hf_token=args.hf_token,
            model_kwargs=args.model_kwargs,
            tokenizer_kwargs=args.tokenizer_kwargs,
            chat_template=args.chat_template,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            do_sample=args.do_sample,
            top_p=args.top_p,
            top_k=args.top_k,
            generation_kwargs=args.generation_kwargs,
            version=args.version,
            output_file=args.output_file,
            questions_path=args.questions_path,
            tables_path=args.tables_path,
            workdir_path=args.workdir_path,
            num_fewshots=args.num_fewshots,
            batch_size=args.batch_size,
            limit=args.limit,
            seed=args.seed,
        )

    @staticmethod
    def _execute_vllm(args: argparse.Namespace) -> None:
        """Call ``llmsql.inference_vllm`` with the parsed options.

        Args:
            args: Namespace produced by the ``vllm`` parser.
        """
        # Imported lazily so the backend is only loaded when this command runs.
        from llmsql import inference_vllm

        inference_vllm(
            model_name=args.model_name,
            trust_remote_code=args.trust_remote_code,
            tensor_parallel_size=args.tensor_parallel_size,
            hf_token=args.hf_token,
            llm_kwargs=args.llm_kwargs,
            use_chat_template=args.use_chat_template,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            do_sample=args.do_sample,
            sampling_kwargs=args.sampling_kwargs,
            version=args.version,
            output_file=args.output_file,
            questions_path=args.questions_path,
            tables_path=args.tables_path,
            workdir_path=args.workdir_path,
            limit=args.limit,
            num_fewshots=args.num_fewshots,
            batch_size=args.batch_size,
            seed=args.seed,
        )
class ParserCLI:
    """Root ``llmsql`` argument parser that registers and dispatches subcommands."""

    def __init__(self) -> None:
        """Build the root parser and register the ``inference`` subcommand."""
        self._parser = argparse.ArgumentParser(
            prog="llmsql",
            description="LLMSQL CLI",
            # Raw string on purpose: in a normal string literal a backslash
            # followed by a newline is a line continuation, which would glue
            # every multi-line shell example below into one long line.
            epilog=textwrap.dedent(r"""
                Examples:

                  # 1️⃣ Transformers backend
                  llmsql inference transformers \
                    --model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \
                    --output-file outputs/preds_transformers.jsonl \
                    --batch-size 8 \
                    --num-fewshots 5

                  # 2️⃣ vLLM backend
                  llmsql inference vllm \
                    --model-name Qwen/Qwen2.5-1.5B-Instruct \
                    --output-file outputs/preds_vllm.jsonl \
                    --batch-size 8 \
                    --num-fewshots 5

                  # 3️⃣ Transformers with model kwargs
                  llmsql inference transformers \
                    --model-or-model-name-or-path meta-llama/Llama-3-8b-instruct \
                    --model-kwargs '{"attn_implementation": "flash_attention_2"}'

                  # 4️⃣ vLLM with LLM kwargs
                  llmsql inference vllm \
                    --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 \
                    --llm-kwargs '{"max_model_len": 4096}'

                Visit https://github.com/LLMSQL/llmsql-benchmark for more
            """),
            # Keeps the epilog's own line breaks instead of re-wrapping it.
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )

        # Fallback handler for a namespace that carries no subcommand handler.
        # Because the subcommand is required below, parse_args() normally
        # exits with a usage error before this default would be dispatched.
        self._parser.set_defaults(func=lambda args: self._parser.print_help())

        self._subparsers = self._parser.add_subparsers(
            dest="command",
            metavar="COMMAND",
            required=True,
        )

        Inference(self._subparsers)

    def parse_args(self) -> argparse.Namespace:
        """Parse the process command line (``sys.argv``).

        Returns:
            The parsed namespace; its ``func`` attribute is the handler to run.
        """
        return self._parser.parse_args()

    def execute(self, args: argparse.Namespace) -> None:
        """Run the handler stored in ``args.func``, passing it ``args``.

        Args:
            args: Namespace returned by :meth:`parse_args`.
        """
        args.func(args)
T = TypeVar("T", bound="SubCommand")


class SubCommand(ABC):
    """Abstract base class for CLI subcommands.

    A concrete subcommand registers its parser in ``__init__`` and declares
    its options in ``_add_args``.
    """

    @abstractmethod
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the subcommand; every subclass must provide its own."""

    @classmethod
    def create(
        cls: type[T],
        subparsers: argparse._SubParsersAction,
        *args: Any,
        **kwargs: Any,
    ) -> T:
        """Create a command instance registered on ``subparsers``.

        Args:
            subparsers: Sub-parser action of the parent parser.
            *args: Extra positional arguments forwarded to the initializer.
            **kwargs: Extra keyword arguments forwarded to the initializer.

        Returns:
            A new instance of the concrete subclass.
        """
        return cls(subparsers, *args, **kwargs)

    @abstractmethod
    def _add_args(self) -> None:
        """Add the arguments specific to this subcommand."""
def test_vllm_backend_called(monkeypatch):
    """The ``vllm`` subcommand forwards the parsed options to ``llmsql.inference_vllm``."""
    # The CLI calls the backend synchronously, so a plain MagicMock is the
    # right stand-in; an AsyncMock would hand back a coroutine that is never
    # awaited. Imported locally to keep this test self-contained.
    from unittest.mock import MagicMock

    mock_inference = MagicMock(return_value=[])

    monkeypatch.setattr(
        "llmsql.inference_vllm",
        mock_inference,
    )

    test_args = [
        "llmsql",
        "inference",
        "vllm",
        "--model-name",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "--tensor-parallel-size",
        "2",
    ]

    # ParserCLI.parse_args() reads sys.argv, so the command line is injected there.
    monkeypatch.setattr(sys, "argv", test_args)

    cli = ParserCLI()
    args = cli.parse_args()

    cli.execute(args)

    mock_inference.assert_called_once()

    call_kwargs = mock_inference.call_args.kwargs
    assert call_kwargs["model_name"] == "mistralai/Mixtral-8x7B-Instruct-v0.1"
    assert call_kwargs["tensor_parallel_size"] == 2
@pytest.fixture(scope="session", autouse=True)
def cleanup_evaluation_results():
    """Remove ``evaluation_results*`` files from the working directory after the session.

    Runs once per test session: everything after ``yield`` executes at
    session teardown.
    """
    yield
    for leftover in glob.glob("evaluation_results*"):
        # Only regular files are deleted: os.remove() raises on a directory,
        # which would turn an unrelated match into a teardown error.
        if os.path.isfile(leftover):
            os.remove(leftover)
execute(self, args): + called["execute"] = args + + monkeypatch.setattr(main_module, "ParserCLI", DummyParser) + + main_module.main() + + assert called["parse_args"] is True + assert called["execute"] == "parsed-args"