From f850a89b06b8fa56f47f6ae3e31ad1179330a16e Mon Sep 17 00:00:00 2001 From: Viknov Date: Thu, 8 Jan 2026 19:12:25 +0100 Subject: [PATCH 1/2] Added optional LoRA adapter support for vLLM inference. --- llmsql/inference/inference_vllm.py | 32 +++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/llmsql/inference/inference_vllm.py b/llmsql/inference/inference_vllm.py index bbeb2a9..ab8e819 100644 --- a/llmsql/inference/inference_vllm.py +++ b/llmsql/inference/inference_vllm.py @@ -22,6 +22,7 @@ max_new_tokens=256, temperature=0.7, tensor_parallel_size=1, + lora_path="path/to/lora" ) Notes @@ -46,6 +47,7 @@ from dotenv import load_dotenv from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest from llmsql.config.config import DEFAULT_WORKDIR_PATH from llmsql.loggers.logging_config import log @@ -69,6 +71,11 @@ def inference_vllm( hf_token: str | None = None, llm_kwargs: dict[str, Any] | None = None, use_chat_template: bool = True, + # === LoRA Parameters === + lora_path: str | None = None, + lora_name: str = "default", + lora_scale: float = 1.0, + max_lora_rank: int = 64, # === Generation Parameters === max_new_tokens: int = 256, temperature: float = 1.0, @@ -98,6 +105,12 @@ def inference_vllm( 'trust_remote_code' are handled separately and will override values here. + # LoRA: + lora_path: Path to pretrained LoRA adapter (optional). + lora_name: Logical name of the LoRA adapter. + lora_scale: Scaling factor for LoRA weights. + max_lora_rank: Maximum LoRA rank supported by vLLM. + # Generation: max_new_tokens: Maximum tokens to generate per sequence. temperature: Sampling temperature (0.0 = greedy). @@ -115,6 +128,8 @@ def inference_vllm( batch_size: Number of questions per generation batch. seed: Random seed for reproducibility. + + Returns: List of dicts containing `question_id` and generated `completion`. 
""" @@ -141,6 +156,8 @@ def inference_vllm( "tokenizer": model_name, "tensor_parallel_size": tensor_parallel_size, "trust_remote_code": trust_remote_code, + "enable_lora": lora_path is not None, + "max_lora_rank": max_lora_rank, **llm_kwargs, # User kwargs come first, but explicit params above will override } @@ -148,6 +165,15 @@ def inference_vllm( llm = LLM(**llm_init_args) + lora_request = None + if lora_path is not None: + log.info(f"Loading LoRA adapter from {lora_path}") + lora_request = LoRARequest( + lora_name=lora_name, + lora_path=lora_path, + scaling=lora_scale, + ) + tokenizer = llm.get_tokenizer() if use_chat_template: use_chat_template = getattr(tokenizer, "chat_template", None) # type: ignore @@ -196,7 +222,11 @@ def inference_vllm( prompts.append(final_prompt) - outputs = llm.generate(prompts, sampling_params) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=lora_request, + ) batch_results: list[dict[str, str]] = [] for q, out in zip(batch, outputs, strict=False): From 9ec17e50fc911611b1535290b47c55afb94b9c37 Mon Sep 17 00:00:00 2001 From: Dzmitry Pihulski Date: Mon, 23 Feb 2026 16:18:59 +0100 Subject: [PATCH 2/2] add: Lora Config added as one argument --- llmsql/inference/inference_vllm.py | 48 ++++++++++++++++---------- tests/inference/test_limit_argument.py | 2 +- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/llmsql/inference/inference_vllm.py b/llmsql/inference/inference_vllm.py index 30126b6..e813646 100644 --- a/llmsql/inference/inference_vllm.py +++ b/llmsql/inference/inference_vllm.py @@ -77,10 +77,7 @@ def inference_vllm( llm_kwargs: dict[str, Any] | None = None, use_chat_template: bool = True, # === LoRA Parameters === - lora_path: str | None = None, - lora_name: str = "default", - lora_scale: float = 1.0, - max_lora_rank: int = 64, + lora_config: dict[str, Any] | None = None, # new optional dict # === Generation Parameters === max_new_tokens: int = 256, temperature: float = 1.0, @@ -112,11 +109,15 
@@ def inference_vllm( 'trust_remote_code' are handled separately and will override values here. - # LoRA: - lora_path: Path to pretrained LoRA adapter (optional). - lora_name: Logical name of the LoRA adapter. - lora_scale: Scaling factor for LoRA weights. - max_lora_rank: Maximum LoRA rank supported by vLLM. + lora_config: Optional dict with LoRA parameters: + - lora_path: Path to the pretrained LoRA adapter (required if enable_lora) + - lora_name: Logical name for the LoRA adapter + - lora_scale: Scaling factor for LoRA weights + - max_lora_rank: Maximum LoRA rank supported by vLLM + LoRA usage rules: + - If `lora_config` is provided, `enable_lora` must be True in `llm_kwargs`. + - If `enable_lora` is True, a valid `lora_config` must be provided. + - Otherwise, an exception is raised to prevent inconsistent configuration. # Generation: max_new_tokens: Maximum tokens to generate per sequence. @@ -182,28 +183,35 @@ def inference_vllm( ) questions = questions[:limit] + # --- Validate LoRA usage --- + enable_lora = llm_kwargs.get("enable_lora", False) + if lora_config is not None and not enable_lora: + raise ValueError( + "LoRA config provided but `enable_lora` is not True in llm_kwargs." + ) + if enable_lora and lora_config is None: + raise ValueError("`enable_lora` is True but no `lora_config` was provided.")
+ # --- init model --- llm_init_args = { "model": model_name, "tokenizer": model_name, "tensor_parallel_size": tensor_parallel_size, "trust_remote_code": trust_remote_code, - "enable_lora": lora_path is not None, - "max_lora_rank": max_lora_rank, - **llm_kwargs, # User kwargs come first, but explicit params above will override + **llm_kwargs, # user overrides } log.info(f"Loading vLLM model '{model_name}' (tp={tensor_parallel_size})...") - llm = LLM(**llm_init_args) + # --- LoRA request --- lora_request = None - if lora_path is not None: - log.info(f"Loading LoRA adapter from {lora_path}") + if enable_lora and lora_config is not None: + log.info(f"Loading LoRA adapter from {lora_config['lora_path']}") lora_request = LoRARequest( - lora_name=lora_name, - lora_path=lora_path, - scaling=lora_scale, + lora_name=lora_config["lora_name"], + lora_path=lora_config["lora_path"], + scaling=lora_config["lora_scale"], ) tokenizer = llm.get_tokenizer() diff --git a/tests/inference/test_limit_argument.py b/tests/inference/test_limit_argument.py index e2d2e09..63cc20a 100644 --- a/tests/inference/test_limit_argument.py +++ b/tests/inference/test_limit_argument.py @@ -46,7 +46,7 @@ def _patch_common_vllm(monkeypatch, tmp_path): ) fake_llm = MagicMock() - fake_llm.generate.side_effect = lambda prompts, _params: [ + fake_llm.generate.side_effect = lambda prompts, *a, **kw: [ MagicMock(outputs=[MagicMock(text=f"SELECT {i}")]) for i in range(len(prompts)) ] monkeypatch.setattr(vllm_mod, "LLM", lambda *a, **kw: fake_llm)