diff --git a/llmsql/inference/inference_vllm.py b/llmsql/inference/inference_vllm.py index f661c6b..e813646 100644 --- a/llmsql/inference/inference_vllm.py +++ b/llmsql/inference/inference_vllm.py @@ -23,6 +23,7 @@ max_new_tokens=256, temperature=0.7, tensor_parallel_size=1, + lora_config={"lora_name": "sql", "lora_path": "path/to/lora"}, # also set llm_kwargs={"enable_lora": True} ) Notes @@ -47,6 +48,7 @@ from dotenv import load_dotenv from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest from llmsql.config.config import ( DEFAULT_LLMSQL_VERSION, @@ -74,6 +76,8 @@ def inference_vllm( hf_token: str | None = None, llm_kwargs: dict[str, Any] | None = None, use_chat_template: bool = True, + # === LoRA Parameters === + lora_config: dict[str, Any] | None = None, # === Generation Parameters === max_new_tokens: int = 256, temperature: float = 1.0, @@ -105,6 +109,16 @@ def inference_vllm( 'trust_remote_code' are handled separately and will override values here. + lora_config: Optional dict with LoRA parameters: + - lora_path: Path to the pretrained LoRA adapter (required) + - lora_name: Logical name for the LoRA adapter (required) + - lora_int_id: Positive integer adapter id (optional, defaults to 1) + Engine options such as `max_lora_rank` are passed via `llm_kwargs`. + LoRA usage rules: + - If `lora_config` is provided, `enable_lora` must be True in `llm_kwargs`. + - If `enable_lora` is True, a valid `lora_config` must be provided. + - A ValueError is raised when either rule is violated. + # Generation: max_new_tokens: Maximum tokens to generate per sequence. temperature: Sampling temperature (0.0 = greedy). @@ -126,6 +140,8 @@ def inference_vllm( the first N samples. If a float between 0.0 and 1.0, evaluates the first X*100% of samples. If None, evaluates all samples (default). + + Returns: List of dicts containing `question_id` and generated `completion`.
""" @@ -167,19 +183,37 @@ def inference_vllm( ) questions = questions[:limit] + # --- Validate LoRA usage --- + enable_lora = llm_kwargs.get("enable_lora", False) + if lora_config is not None and not enable_lora: + raise ValueError( + "LoRA config provided but `enable_lora` is not True in llm_kwargs." + ) + if enable_lora and lora_config is None: + raise ValueError("`enable_lora` is True but no `lora_config` was provided.") + # --- init model --- llm_init_args = { "model": model_name, "tokenizer": model_name, "tensor_parallel_size": tensor_parallel_size, "trust_remote_code": trust_remote_code, - **llm_kwargs, # User kwargs come first, but explicit params above will override + **llm_kwargs, # user overrides } log.info(f"Loading vLLM model '{model_name}' (tp={tensor_parallel_size})...") - llm = LLM(**llm_init_args) + # --- LoRA request --- + lora_request = None + if enable_lora and lora_config is not None: + log.info(f"Loading LoRA adapter from {lora_config['lora_path']}") + lora_request = LoRARequest( + lora_name=lora_config["lora_name"], + lora_path=lora_config["lora_path"], + lora_int_id=lora_config.get("lora_int_id", 1), # must be >= 1 + ) + tokenizer = llm.get_tokenizer() if use_chat_template: use_chat_template = getattr(tokenizer, "chat_template", None) # type: ignore @@ -228,7 +262,11 @@ def inference_vllm( prompts.append(final_prompt) - outputs = llm.generate(prompts, sampling_params) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=lora_request, + ) batch_results: list[dict[str, str]] = [] for q, out in zip(batch, outputs, strict=False): diff --git a/tests/inference/test_limit_argument.py b/tests/inference/test_limit_argument.py index e2d2e09..63cc20a 100644 --- a/tests/inference/test_limit_argument.py +++ b/tests/inference/test_limit_argument.py @@ -46,7 +46,7 @@ def _patch_common_vllm(monkeypatch, tmp_path): )
fake_llm = MagicMock() - fake_llm.generate.side_effect = lambda prompts, _params: [ + fake_llm.generate.side_effect = lambda prompts, *a, **kw: [ MagicMock(outputs=[MagicMock(text=f"SELECT {i}")]) for i in range(len(prompts)) ] monkeypatch.setattr(vllm_mod, "LLM", lambda *a, **kw: fake_llm)